Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 16 additions & 6 deletions esrally/driver/driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import math
import multiprocessing
import queue
import sys
import threading
import time
from dataclasses import dataclass
Expand Down Expand Up @@ -78,8 +79,9 @@ class Bootstrap:
Prompts loading of track code on new actors
"""

def __init__(self, cfg):
def __init__(self, cfg, worker_id=None):
self.config = cfg
self.worker_id = worker_id


class PrepareTrack:
Expand Down Expand Up @@ -304,8 +306,10 @@ def receiveMsg_WakeupMessage(self, msg, sender):
self.driver.update_progress_message()
self.wakeupAfter(datetime.timedelta(seconds=DriverActor.WAKEUP_INTERVAL_SECONDS))

def create_client(self, host, cfg):
return self.createActor(Worker, targetActorRequirements=self._requirements(host))
def create_client(self, host, cfg, worker_id):
worker = self.createActor(Worker, targetActorRequirements=self._requirements(host))
self.send(worker, Bootstrap(cfg, worker_id))
return worker

def start_worker(self, driver, worker_id, cfg, track, allocations, client_contexts=None):
self.send(driver, StartWorker(worker_id, cfg, track, allocations, client_contexts))
Expand Down Expand Up @@ -768,7 +772,7 @@ def start_benchmark(self):
# don't assign workers without any clients
if len(clients) > 0:
self.logger.debug("Allocating worker [%d] on [%s] with [%d] clients.", worker_id, host, len(clients))
worker = self.target.create_client(host, self.config)
worker = self.target.create_client(host, self.config, worker_id)

client_allocations = ClientAllocations()
worker_client_contexts = {}
Expand Down Expand Up @@ -1211,11 +1215,17 @@ def __init__(self):
self.sample_queue_size = None

@actor.no_retry("worker") # pylint: disable=no-value-for-parameter
def receiveMsg_StartWorker(self, msg, sender):
self.logger.info("Worker[%d] is about to start.", msg.worker_id)
def receiveMsg_Bootstrap(self, msg, sender):
self.driver_actor = sender
self.worker_id = msg.worker_id
# load node-specific config to have correct paths available
self.config = load_local_config(msg.config)
load_track(self.config, install_dependencies=False)
self.logger.debug("Worker[%d] has Python load path %s after bootstrap.", self.worker_id, sys.path)

@actor.no_retry("worker") # pylint: disable=no-value-for-parameter
def receiveMsg_StartWorker(self, msg, sender):
self.logger.info("Worker[%d] is about to start.", msg.worker_id)
self.on_error = self.config.opts("driver", "on.error")
self.sample_queue_size = int(self.config.opts("reporting", "sample.queue.size", mandatory=False, default_value=1 << 20))
self.track = msg.track
Expand Down
9 changes: 8 additions & 1 deletion esrally/utils/repo.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,4 +126,11 @@ def checkout(self, revision):
git.checkout(self.repo_dir, branch=revision)

def correct_revision(self, revision):
return git.head_revision(self.repo_dir) == revision
if git.is_branch(self.repo_dir, revision):
current_branch = git.current_branch(self.repo_dir)
self.logger.info("Checking current branch [%s] is equal to specified branch [%s].", current_branch, revision)
return current_branch == revision

current_revision = git.head_revision(self.repo_dir)
self.logger.info("Checking current revision [%s] is equal to specified revision [%s].", current_revision, revision)
return current_revision == revision
18 changes: 10 additions & 8 deletions tests/driver/driver_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,7 @@ def create_test_driver_target(self):

@mock.patch("esrally.utils.net.resolve")
def test_start_benchmark_and_prepare_track(self, resolve):
worker_id = [0, 1, 2, 3]
# override load driver host
self.cfg.add(config.Scope.applicationOverride, "driver", "load_driver_hosts", ["10.5.5.1", "10.5.5.2"])
resolve.side_effect = ["10.5.5.1", "10.5.5.2"]
Expand All @@ -179,10 +180,10 @@ def test_start_benchmark_and_prepare_track(self, resolve):

target.create_client.assert_has_calls(
calls=[
mock.call("10.5.5.1", d.config),
mock.call("10.5.5.1", d.config),
mock.call("10.5.5.2", d.config),
mock.call("10.5.5.2", d.config),
mock.call("10.5.5.1", d.config, worker_id[0]),
mock.call("10.5.5.1", d.config, worker_id[1]),
mock.call("10.5.5.2", d.config, worker_id[2]),
mock.call("10.5.5.2", d.config, worker_id[3]),
]
)

Expand Down Expand Up @@ -216,6 +217,7 @@ def test_prepare_serverless_benchmark(self, mock_method):
}

def test_assign_drivers_round_robin(self):
worker_id = [0, 1, 2, 3]
target = self.create_test_driver_target()
d = driver.Driver(target, self.cfg, es_client_factory_class=self.StaticClientFactory)

Expand All @@ -227,10 +229,10 @@ def test_assign_drivers_round_robin(self):

target.create_client.assert_has_calls(
calls=[
mock.call("localhost", d.config),
mock.call("localhost", d.config),
mock.call("localhost", d.config),
mock.call("localhost", d.config),
mock.call("localhost", d.config, worker_id[0]),
mock.call("localhost", d.config, worker_id[1]),
mock.call("localhost", d.config, worker_id[2]),
mock.call("localhost", d.config, worker_id[3]),
]
)

Expand Down