Description
Provide the ability to limit task concurrency _per worker_.
Use case / motivation
My use case is that I have a particularly heavy task - one that uses lots of RAM & GPU - where if too many instances of that task are running on the same machine at a time, it'll crash. My ideal situation is to have a flag on the operator, something like task_concurrency_per_worker, that'll guarantee at most N instances of a particular task are running on that worker at a time.
For example, if I trigger 4 instances of DAG A right now, even with 4 workers and task_concurrency = 4 on the operator, I believe there's no guarantee that each worker will receive at most 1 instance of the task, and hence it could end up with e.g. 2 instances on worker #1 and 2 instances on worker #2.
Another heavy-handed solution would be reducing worker_concurrency, but that would restrict worker concurrency for all tasks & DAGs, and so isn't ideal as it's overly restrictive.
Said another way, this feature request is to basically combine task_concurrency on the operator and worker_concurrency to make a task-specific worker concurrency.
I primarily work with the CeleryExecutor; I'm not familiar enough with the other non-local executors to know if this is a reasonable request for those executor types.
Thanks!
Thinking about this more - I guess this can be implemented with a separate queue and a separate worker, where that worker has its celery.worker_concurrency set to 1 and listens to this separate queue. That said, it's not the most convenient, since it involves spinning up a separate worker with its own config and its own queue.
I'll leave this open for the time being in case there's a better solution I'm missing!
separate queue and a separate worker, where that worker has its celery.worker_concurrency set to 1 and listens to this separate queue
I would say that this is the recommended way. It's common practice to have a separate queue for compute-heavy tasks. If you are using Kubernetes, then once Airflow 2.0 is released you can consider using the CeleryKubernetesExecutor https://airflow.readthedocs.io/en/latest/executor/celery_kubernetes.html
Btw. I know that @mjpieters had a similar problem, not sure what the final solution was 😉
Sounds good - I guess I'll go with this. Thanks!
My problem was a little different, but the same approach could work.
You'd need to plug into Celery and direct jobs to workers with capacity using per-worker queues. To make this work across schedulers and workers, I've so far used Redis to share bookkeeping information and to ensure consistency when multiple clients update that information.
I'd set each worker's capacity to its concurrency level and give tasks a default cost of 1; your expensive tasks can then use larger values to reserve a certain 'chunk' of your worker resources.
You can re-route tasks in Celery by hooking into the task_routes configuration. If you set this to a function, that function is called for every routing decision. In Airflow, you can set this hook by supplying a custom CELERY_CONFIG dictionary:
from airflow.config_templates.default_celery import DEFAULT_CELERY_CONFIG
CELERY_CONFIG = {
    **DEFAULT_CELERY_CONFIG,
    "task_routes": "[module with route_tasks].route_tasks",
}
and in a separate module (as Celery should not try to import it until after configuration is complete and the task router is actually needed):
def route_tasks(name, args, kwargs, options, task=None, **kw):
    """Custom task router

    name: the Celery task name. Airflow has just one Celery task name.
    args: positional arguments for the task. Airflow: args[0] is the args list for the 'airflow task run' command,
        e.g. ['airflow', 'task', 'run', dag_id, task_id, execution_date, ...]
    kwargs: keyword arguments for the task. Airflow: always empty
    options: the Celery task options dictionary. Contains existing routing information (such as 'queue').
        Returned information is merged with this dictionary, with this dictionary *taking precedence*.
    task: Celery task object. Can be used to access the Celery app via task.app, etc.

    If the return value is not falsy (None, {}, ''), it must either be a string (name of a queue), or a dictionary
    with routing options.

    *Note*: Airflow sets a default queue in options. Delete it from that dictionary if you want to redirect
    a task.
    """
and in airflow.cfg, set celery_config_options:
[celery]
# ...
celery_config_options = [module with the celery config].CELERY_CONFIG
You can then use Celery signal handlers to maintain worker capacity. Which hooks you need depends on how you get your task 'size' data, but assuming a hardcoded map, you'd use:
- celeryd_after_setup to generate a worker queue name to listen to.
- worker_ready to add the worker queue to Redis with the total worker capacity.
- worker_shutting_down to remove the worker from Redis altogether.
- task_postrun to return the task's size to the worker's capacity.

The task router is responsible for reducing the available worker capacity. Do this as soon as you make a routing decision, so that no further tasks are sent to a worker that is already committed to a workload.
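To make that lifecycle concrete, here is an in-memory stand-in for the bookkeeping, with one method per hook listed above. CapacityLedger is a hypothetical, single-process illustration only: it has none of the cross-process consistency that the Redis sorted set and WATCH provide, and the queue names are made up.

```python
from typing import Dict, Optional


class CapacityLedger:
    """Hypothetical in-memory ledger; the real setup keeps this in a Redis sorted set."""

    def __init__(self) -> None:
        self._free: Dict[str, int] = {}  # worker queue name -> free capacity

    def worker_ready(self, queue: str, concurrency: int) -> None:
        # worker_ready: register the worker with its full capacity
        self._free[queue] = concurrency

    def worker_shutting_down(self, queue: str) -> None:
        # worker_shutting_down: forget the worker entirely
        self._free.pop(queue, None)

    def route(self, task_size: int) -> Optional[str]:
        # router: pick the worker with the most free capacity and reserve task_size on it
        if not self._free:
            return None
        queue = max(self._free, key=self._free.__getitem__)
        if self._free[queue] < task_size:
            return None  # nothing fits; fall back to the default queue
        self._free[queue] -= task_size
        return queue

    def task_postrun(self, queue: str, task_size: int) -> None:
        # task_postrun: hand the reserved capacity back, ignoring workers that have gone away
        if queue in self._free:
            self._free[queue] += task_size
```

With two workers of concurrency 4 and a heavy task cost of 4, at most one heavy task can be routed to each worker at a time, which is the per-worker limit the original request asked for.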
In Redis, store the capacity in a sorted set (ZADD worker-capacity [worker capacity] [worker-queue]; ZADD takes the score before the member) so you can quickly access the worker with the most capacity. Unfortunately there is no single command that both gets the worker with the most capacity and decrements its capacity, so use a pipeline with WATCH: read the top entry with ZREVRANGE worker-capacity 0 0 WITHSCORES, then decrement it with ZINCRBY inside MULTI. Keep the steps before MULTI read-only. Popping the entry with ZPOPMAX at that point would itself count as a change to the watched key and abort the transaction, and it would drop the worker from the set whenever the task doesn't fit:
from redis.exceptions import WatchError

WORKER_CAPACITY_KEY = "worker-capacity"


def best_worker(redis, task_size):
    """Reserve task_size on the worker with the most capacity; return its queue name or None."""
    with redis.pipeline() as pipe:
        while True:
            try:
                pipe.watch(WORKER_CAPACITY_KEY)
                # read-only until MULTI: fetch the worker with the most free capacity
                top = pipe.zrevrange(WORKER_CAPACITY_KEY, 0, 0, withscores=True)
                if not top or top[0][1] < task_size:
                    # no workers with capacity available for this task
                    # decide what you need to do in this case. This returns None
                    # to use the default queue. We can't wait forever here as Airflow will
                    # time out the celery executor job eventually.
                    return None
                worker_queue = top[0][0]
                pipe.multi()
                # reserve the capacity; EXEC raises WatchError if the set changed since WATCH
                pipe.zincrby(WORKER_CAPACITY_KEY, -task_size, worker_queue)
                pipe.execute()
                return worker_queue.decode() if isinstance(worker_queue, bytes) else worker_queue
            except WatchError:
                # something else altered the WORKER_CAPACITY_KEY sorted set
                # so retry.
                continue
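The other hooks only need single Redis commands. A minimal sketch, assuming a redis-py 3.0+ client and the same WORKER_CAPACITY_KEY as best_worker; register_worker, release_capacity and remove_worker are hypothetical names, and wiring them into the actual Celery signal handlers is left out.

```python
# Hypothetical helpers for the worker_ready, task_postrun and worker_shutting_down hooks.
# `redis` is assumed to be a redis-py client, e.g. one borrowed from Celery's connection pool.
WORKER_CAPACITY_KEY = "worker-capacity"


def register_worker(redis, worker_queue, capacity):
    """worker_ready: publish the worker queue with its full capacity."""
    redis.zadd(WORKER_CAPACITY_KEY, {worker_queue: capacity})


def release_capacity(redis, worker_queue, task_size):
    """task_postrun: give task_size back to the worker.

    ZADD with XX and INCR only increments an existing member, so capacity released
    after worker_shutting_down has run does not resurrect the worker.
    """
    redis.zadd(WORKER_CAPACITY_KEY, {worker_queue: task_size}, xx=True, incr=True)


def remove_worker(redis, worker_queue):
    """worker_shutting_down: drop the worker from the sorted set."""
    redis.zrem(WORKER_CAPACITY_KEY, worker_queue)
```

Using ZADD XX INCR for the release, rather than ZINCRBY (which creates missing members), is a design choice in this sketch: a task that finishes after its worker has deregistered won't put a stale entry back into the set.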
Finally, if you are already using Redis as your Celery broker, you can reuse the Celery connection pool for these tasks. If you are using RabbitMQ instead, you'll have to maintain your own connection pool or use a different method of sharing this information between the different components.
To reuse the connection pool, get access to the Celery app. How you do this depends on where the code runs: inside the router function, the task object has an app attribute, for example, while the worker signal hooks are passed a worker instance object, which again has an app attribute. Given the Celery app in app, you can then use:
# the connection goes back to the pool on exit; only use redis inside this block
with app.pool.acquire(block=True) as conn, conn.channel() as channel:
    redis = channel.client
You may have to increase the Celery broker_pool_limit configuration however, depending on how busy your cluster gets.
One of these days I may actually write that blog post on this subject I was planning, but the above will have to do for now.
One of these days I may actually write that blog post on this subject I was planning, but the above will have to do for now.
That would be awesome! Possibly we can put it somewhere in Airflow docs.
@dimberman take a look at that, I think this is really cool stuff 🚀
@mjpieters thanks for the great explanation!