Module: celery_queue
Expand source code
# Copyright (C) 2023-present The Project Contributors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import multiprocessing
import os
from dataclasses import dataclass
from typing import Final
from uuid import UUID
from celery import Celery
from cl.runtime import Context
from cl.runtime.log.exceptions.user_error import UserError
from cl.runtime.log.log_entry import LogEntry
from cl.runtime.log.log_entry_level_enum import LogEntryLevelEnum
from cl.runtime.log.user_log_entry import UserLogEntry
from cl.runtime.primitive.datetime_util import DatetimeUtil
from cl.runtime.records.protocols import TDataDict
from cl.runtime.records.protocols import is_key
from cl.runtime.records.protocols import is_record
from cl.runtime.serialization.dict_serializer import DictSerializer
from cl.runtime.settings.context_settings import ContextSettings
from cl.runtime.settings.project_settings import ProjectSettings
from cl.runtime.tasks.task import Task
from cl.runtime.tasks.task_key import TaskKey
from cl.runtime.tasks.task_queue import TaskQueue
from cl.runtime.tasks.task_queue_key import TaskQueueKey
from cl.runtime.tasks.task_status_enum import TaskStatusEnum
CELERY_MAX_WORKERS: Final[int] = 4
CELERY_RUN_COMMAND_QUEUE: Final[str] = "run_command"
CELERY_MAX_RETRIES: Final[int] = 3
CELERY_TIME_LIMIT: Final[int] = 3600 * 2  # TODO: Make the 2-hour time limit configurable
databases_dir = ProjectSettings.get_databases_dir()
context_id = ContextSettings.instance().context_id
# Get the SQLite file name of the Celery broker based on the context id in settings
celery_file = os.path.join(databases_dir, f"{context_id}.celery.sqlite")
celery_sqlite_uri = f"sqlalchemy+sqlite:///{celery_file}"
celery_app = Celery(
"worker",
broker=celery_sqlite_uri,
broker_connection_retry_on_startup=True,
)
celery_app.conf.task_track_started = True
context_serializer = DictSerializer()
"""Serializer for the context parameter of 'execute_task' method."""
@celery_app.task(max_retries=0) # Do not retry failed tasks
def execute_task(
task_id: str,
context_data: TDataDict,
) -> None:
"""Invoke 'run_task' method of the specified task."""
    # Set the is_deserialized flag in context data; it is used to skip some of the initialization code
context_data["is_deserialized"] = True
# Deserialize context from 'context_data' parameter to run with the same settings as the caller context
with context_serializer.deserialize_data(context_data) as context:
# Load and run the task
task_key = TaskKey(task_id=task_id)
task = context.load_one(Task, task_key)
task.run_task()
def celery_start_queue_callable(*, log_dir: str) -> None:
"""
Callable for starting the celery queue process.
Args:
log_dir: Directory where Celery console log file will be written
"""
    # Console output from Celery should be redirected to a log file (the redirection below is currently disabled)
    # TODO: Use an additional Logger handler instead
    log_file_path = os.path.join(log_dir, "celery_queue.log")
# with open(log_file_path, "w") as log_file:
# os.dup2(log_file.fileno(), 1) # Redirect stdout (file descriptor 1)
# os.dup2(log_file.fileno(), 2) # Redirect stderr (file descriptor 2)
celery_app.worker_main(
argv=[
"-A",
"cl.runtime.tasks.celery.celery_queue",
"worker",
"--loglevel=info",
f"--autoscale={CELERY_MAX_WORKERS},1",
f"--pool=solo", # One concurrent task per worker, do not switch to prefork (not supported on Windows)
f"--concurrency=1", # Use only for prefork, one concurrent task per worker (similar to solo)
],
)
def celery_delete_existing_tasks() -> None:
"""Delete the existing Celery tasks (will exit when the current process exits)."""
# Remove sqlite file of celery broker if exists
if os.path.exists(celery_file):
os.remove(celery_file)
def celery_start_queue(*, log_dir: str) -> None:
"""
Start Celery workers (will exit when the current process exits).
Args:
log_dir: Directory where Celery console log file will be written
"""
worker_process = multiprocessing.Process(
target=celery_start_queue_callable, daemon=True, kwargs={"log_dir": log_dir}
)
worker_process.start()
@dataclass(slots=True, kw_only=True)
class CeleryQueue(TaskQueue):
"""Execute tasks using Celery."""
# max_workers: int = missing() # TODO: Implement support for max_workers
"""The maximum number of processes running concurrently."""
# TODO: @abstractmethod
def run_start_queue(self) -> None:
"""Start queue workers."""
# TODO: @abstractmethod
def run_stop_queue(self) -> None:
"""Cancel all active runs and stop queue workers."""
def submit_task(self, task: TaskKey):
# Get and serialize current context
context = Context.current()
context_data = context_serializer.serialize_data(context)
# Pass parameters to the Celery task signature
execute_task_signature = execute_task.s(
task.task_id,
context_data,
)
        # Submit the task to Celery
execute_task_signature.apply_async(
retry=False, # Do not retry in case the task fails
ignore_result=True, # TODO: Do not publish to the Celery result backend
)
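Usage sketch: the module-level functions above are typically combined to bring up a local queue, first deleting any tasks left over from a previous run and then starting the worker process in the background. This is a minimal, hypothetical sketch, not part of the module itself; the log directory path is an assumption, and the __main__ guard is included because celery_start_queue launches a child process.
if __name__ == "__main__":
    # Remove the broker's sqlite file so stale tasks are not picked up
    celery_delete_existing_tasks()
    # Start Celery workers in a daemon process (stops when this process exits)
    celery_start_queue(log_dir="logs")  # Hypothetical log directory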
Global variables
var context_serializer
-
Serializer for the context parameter of ‘execute_task’ method.
Functions
def celery_delete_existing_tasks() -> None
-
Delete the existing Celery tasks (removes the broker's sqlite file).
def celery_start_queue(*, log_dir: str) -> None
-
Start Celery workers (will exit when the current process exits).
Args
log_dir
- Directory where Celery console log file will be written
def celery_start_queue_callable(*, log_dir: str) -> None
-
Callable for starting the celery queue process.
Args
log_dir
- Directory where Celery console log file will be written
def execute_task(task_id: str, context_data: TDataDict) -> None
-
Invoke ‘run_task’ method of the specified task.
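For orientation, this entry point is not called directly; CeleryQueue.submit_task builds a Celery signature for it, passing the task id and the caller's serialized Context so the worker can rebuild the same settings before loading and running the Task. A minimal sketch of that dispatch, assuming a Context is already active (the task id below is hypothetical):
context_data = context_serializer.serialize_data(Context.current())
execute_task.s("example-task-id", context_data).apply_async(retry=False, ignore_result=True)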
Classes
class CeleryQueue (*, queue_id: str = None, timeout_sec: int = 10)
-
Execute tasks using Celery.
Expand source code
@dataclass(slots=True, kw_only=True)
class CeleryQueue(TaskQueue):
    """Execute tasks using Celery."""

    # max_workers: int = missing()  # TODO: Implement support for max_workers
    """The maximum number of processes running concurrently."""

    # TODO: @abstractmethod
    def run_start_queue(self) -> None:
        """Start queue workers."""

    # TODO: @abstractmethod
    def run_stop_queue(self) -> None:
        """Cancel all active runs and stop queue workers."""

    def submit_task(self, task: TaskKey):
        # Get and serialize current context
        context = Context.current()
        context_data = context_serializer.serialize_data(context)

        # Pass parameters to the Celery task signature
        execute_task_signature = execute_task.s(
            task.task_id,
            context_data,
        )

        # Submit the task to Celery
        execute_task_signature.apply_async(
            retry=False,  # Do not retry in case the task fails
            ignore_result=True,  # TODO: Do not publish to the Celery result backend
        )
Ancestors
- TaskQueue
- TaskQueueKey
- KeyMixin
- abc.ABC
Static methods
def get_key_type() -> Type
-
Inherited from:
TaskQueue
.get_key_type
Return key type even when called from a record.
Fields
var queue_id -> str
-
Inherited from:
TaskQueue
.queue_id
Unique task queue identifier.
var timeout_sec -> int
-
Inherited from:
TaskQueue
.timeout_sec
Optional timeout in seconds; the queue will stop after reaching this timeout.
Methods
def run_start_queue(self) -> None
-
Start queue workers.
def run_stop_queue(self) -> None
-
Cancel all active runs and stop queue workers.
def submit_task(self, task: TaskKey)
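A hedged usage sketch for submitting work through this queue: it assumes a Context is already active (submit_task calls Context.current() internally) and that a Task record with the given id has already been saved; the queue id and task id values below are hypothetical.
queue = CeleryQueue(queue_id="celery_queue_example")
queue.submit_task(TaskKey(task_id="example-task-id"))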