diff --git a/flower/api/tasks.py b/flower/api/tasks.py index 730c290e4..d154224a8 100644 --- a/flower/api/tasks.py +++ b/flower/api/tasks.py @@ -493,6 +493,8 @@ def get(self): :query state: filter tasks by state :query received_start: filter tasks by received date (must be greater than) format %Y-%m-%d %H:%M :query received_end: filter tasks by received date (must be less than) format %Y-%m-%d %H:%M +:query only_fields: returns only selected fields for tasks (comma-separated) +:query except_fields: returns all but selected fields for tasks (comma-separated) :reqheader Authorization: optional OAuth token to authenticate :statuscode 200: no error :statuscode 401: unauthorized request @@ -507,6 +509,8 @@ def get(self): received_end = self.get_argument('received_end', None) sort_by = self.get_argument('sort_by', None) search = self.get_argument('search', None) + only_fields = self.get_argument('only_fields', None) + except_fields = self.get_argument('except_fields', None) limit = limit and int(limit) offset = max(offset, 0) @@ -522,7 +526,7 @@ def get(self): received_end=received_end, search=search ): - task = tasks.as_dict(task) + task = tasks.as_dict(task, only_fields=only_fields, except_fields=except_fields) worker = task.pop('worker', None) if worker is not None: task['worker'] = worker.hostname @@ -621,18 +625,22 @@ def get(self, taskid): "worker": "celery@worker1" } +:query only_fields: returns only selected fields for task (comma-separated) +:query except_fields: returns all but selected fields for task (comma-separated) :reqheader Authorization: optional OAuth token to authenticate :statuscode 200: no error :statuscode 401: unauthorized request :statuscode 404: unknown task """ + only_fields = self.get_argument('only_fields', None) + except_fields = self.get_argument('except_fields', None) task = tasks.get_task_by_id(self.application.events, taskid) if not task: raise HTTPError(404, f"Unknown task '{taskid}'") - response = task.as_dict() - if task.worker is not None: + response = tasks.as_dict(task, only_fields=only_fields, except_fields=except_fields) + if task.worker is not None and 'worker' in response: response['worker'] = task.worker.hostname self.write(response) diff --git a/flower/utils/tasks.py b/flower/utils/tasks.py index 4abcd82f6..3a0060daa 100644 --- a/flower/utils/tasks.py +++ b/flower/utils/tasks.py @@ -66,5 +66,34 @@ def get_task_by_id(events, task_id): return events.state.tasks.get(task_id) -def as_dict(task): - return task.as_dict() +def filter_dict(task_dict, only_fields=None, except_fields=None): + """ + Filter a dictionary based on only_fields or except_fields parameters. + + Args: + task_dict (dict): The dictionary to filter + only_fields (str or list): Fields to include (excludes all others) + except_fields (str or list): Fields to exclude (includes all others) + + Returns: + dict: The filtered dictionary + """ + if only_fields: + # Convert comma-separated string to list if necessary + if isinstance(only_fields, str): + only_fields = [field.strip() for field in only_fields.split(',')] + # Keep only the specified fields + return {k: v for k, v in task_dict.items() if k in only_fields} + elif except_fields: + # Convert comma-separated string to list if necessary + if isinstance(except_fields, str): + except_fields = [field.strip() for field in except_fields.split(',')] + # Remove the specified fields + return {k: v for k, v in task_dict.items() if k not in except_fields} + + return task_dict + + +def as_dict(task, only_fields=None, except_fields=None): + task_dict = task.as_dict() + return filter_dict(task_dict, only_fields, except_fields) diff --git a/tests/unit/api/test_tasks.py b/tests/unit/api/test_tasks.py index 551957d7e..94c12ec53 100644 --- a/tests/unit/api/test_tasks.py +++ b/tests/unit/api/test_tasks.py @@ -9,6 +9,7 @@ from celery.result import AsyncResult from flower.events import EventsState +from flower.utils.tasks import filter_dict from tests.unit.utils import task_succeeded_events from . import BaseApiTestCase @@ -93,7 +94,28 @@ class MockTasks: @staticmethod def get_task_by_id(events, task_id): from celery.events.state import Task - return Task() + task = Task() + # Set some test data on the task + task.name = 'test_task' + task.state = 'SUCCESS' + return task + + @staticmethod + def as_dict(task, only_fields=None, except_fields=None): + # Create a mock dictionary with test data + task_dict = { + 'name': task.name, + 'state': task.state, + 'worker': 'test_worker', + 'received': 1234567890, + 'started': 1234567891, + 'succeeded': 1234567892, + 'timestamp': 1234567892, + 'runtime': 2.0 + } + + # Use the filter_dict function to handle field filtering + return filter_dict(task_dict, only_fields, except_fields) class TaskTests(BaseApiTestCase): @@ -106,7 +128,68 @@ def get_app(self, capp=None): @patch('flower.api.tasks.tasks', new=MockTasks) def test_task_info(self): - self.get('/api/task/info/123') + # Make the request + r = self.get('/api/task/info/123') + + # Parse the response + task = json.loads(r.body.decode("utf-8")) + + # Assert the response status code + self.assertEqual(200, r.code) + + # Assert the task data + self.assertEqual('test_task', task['name']) + self.assertEqual('SUCCESS', task['state']) + self.assertEqual('test_worker', task['worker']) + self.assertEqual(1234567890, task['received']) + self.assertEqual(1234567891, task['started']) + self.assertEqual(1234567892, task['succeeded']) + self.assertEqual(1234567892, task['timestamp']) + self.assertEqual(2.0, task['runtime']) + + def test_task_info_field_selection(self): + state = EventsState() + state.get_or_create_worker('worker1') + events = [Event('worker-online', hostname='worker1')] + events += task_succeeded_events(worker='worker1', name='task1', + id='123') + + for i, e in enumerate(events): + e['clock'] = i + e['local_received'] = time.time() + state.event(e) + self.app.events.state = state + + # Test only_fields parameter + params = dict(only_fields='name,state') + + r = self.get('/api/task/info/123?' + '&'.join( + map(lambda x: '%s=%s' % x, params.items()))) + + task = json.loads(r.body.decode("utf-8")) + + self.assertEqual(200, r.code) + # Check that only the specified fields are returned + self.assertEqual(2, len(task)) + self.assertIn('name', task) + self.assertIn('state', task) + self.assertNotIn('worker', task) + self.assertNotIn('received', task) + + # Test except_fields parameter + params = dict(except_fields='worker,received') + + r = self.get('/api/task/info/123?' + '&'.join( + map(lambda x: '%s=%s' % x, params.items()))) + + task = json.loads(r.body.decode("utf-8")) + + self.assertEqual(200, r.code) + # Check that the specified fields are not returned + self.assertIn('name', task) + self.assertIn('state', task) + self.assertNotIn('worker', task) + self.assertNotIn('received', task) def test_tasks_pagination(self): state = EventsState() @@ -216,3 +299,38 @@ def test_tasks_pagination(self): self.assertEqual(1, len(table)) firstFetchedTaskName = table[list(table)[0]]['name'] self.assertEqual("task1", firstFetchedTaskName) + + # Test only_fields parameter + params = dict(limit=4, offset=0, sort_by='name', only_fields='name,state') + + r = self.get('/api/tasks?' + '&'.join( + map(lambda x: '%s=%s' % x, params.items()))) + + table = json.loads(r.body.decode("utf-8"), object_pairs_hook=OrderedDict) + + self.assertEqual(200, r.code) + self.assertEqual(4, len(table)) + # Check that only the specified fields are returned + task = table[list(table)[0]] + self.assertEqual(2, len(task)) + self.assertIn('name', task) + self.assertIn('state', task) + self.assertNotIn('worker', task) + self.assertNotIn('received', task) + + # Test except_fields parameter + params = dict(limit=4, offset=0, sort_by='name', except_fields='worker,received') + + r = self.get('/api/tasks?' + '&'.join( + map(lambda x: '%s=%s' % x, params.items()))) + + table = json.loads(r.body.decode("utf-8"), object_pairs_hook=OrderedDict) + + self.assertEqual(200, r.code) + self.assertEqual(4, len(table)) + # Check that the specified fields are not returned + task = table[list(table)[0]] + self.assertIn('name', task) + self.assertIn('state', task) + self.assertNotIn('worker', task) + self.assertNotIn('received', task)