From b5caa8a8f17ec93414e47d812d6e834919d89356 Mon Sep 17 00:00:00 2001 From: Robin Van de Merghel Date: Fri, 13 Jun 2025 13:41:30 +0200 Subject: [PATCH 1/9] feat: Add pilot management: create/delete/patch and query --- .../src/diracx/client/_generated/_client.py | 5 +- .../diracx/client/_generated/aio/_client.py | 5 +- .../_generated/aio/operations/__init__.py | 2 + .../_generated/aio/operations/_operations.py | 586 +++++++++++++++ .../_generated/aio/operations/_patch.py | 2 + .../client/_generated/models/__init__.py | 8 + .../diracx/client/_generated/models/_enums.py | 13 + .../client/_generated/models/_models.py | 199 +++++ .../client/_generated/operations/__init__.py | 2 + .../_generated/operations/_operations.py | 694 ++++++++++++++++++ .../client/_generated/operations/_patch.py | 2 + .../src/diracx/client/patches/pilots/aio.py | 53 ++ .../diracx/client/patches/pilots/common.py | 147 ++++ .../src/diracx/client/patches/pilots/sync.py | 53 ++ diracx-core/src/diracx/core/exceptions.py | 25 +- diracx-core/src/diracx/core/models/job.py | 2 +- diracx-core/src/diracx/core/models/pilot.py | 33 + diracx-db/src/diracx/db/sql/__init__.py | 2 +- diracx-db/src/diracx/db/sql/dummy/db.py | 8 +- diracx-db/src/diracx/db/sql/job/db.py | 3 +- .../src/diracx/db/sql/pilot_agents/db.py | 45 -- .../sql/{pilot_agents => pilots}/__init__.py | 0 diracx-db/src/diracx/db/sql/pilots/db.py | 241 ++++++ .../db/sql/{pilot_agents => pilots}/schema.py | 5 +- diracx-db/src/diracx/db/sql/utils/__init__.py | 16 + .../pilot_agents/test_pilot_agents_db.py | 30 - .../{pilot_agents => pilots}/__init__.py | 0 .../tests/pilots/test_pilot_management.py | 193 +++++ diracx-db/tests/pilots/test_query.py | 292 ++++++++ diracx-db/tests/pilots/utils.py | 146 ++++ diracx-db/tests/test_dummy_db.py | 1 + .../src/diracx/logic/pilots/__init__.py | 0 .../src/diracx/logic/pilots/management.py | 122 +++ diracx-logic/src/diracx/logic/pilots/query.py | 183 +++++ diracx-routers/pyproject.toml | 38 +- .../src/diracx/routers/pilots/__init__.py | 13 + .../diracx/routers/pilots/access_policies.py | 125 ++++ .../src/diracx/routers/pilots/management.py | 257 +++++++ .../src/diracx/routers/pilots/query.py | 165 +++++ .../tests/pilots/test_pilot_creation.py | 281 +++++++ diracx-routers/tests/pilots/test_query.py | 406 ++++++++++ docs/dev/explanations/pilots.md | 20 + .../client/_generated/models/__init__.py | 8 + .../client/_generated/models/_enums.py | 13 + .../client/_generated/models/_models.py | 199 +++++ pixi.lock | 46 +- 46 files changed, 4558 insertions(+), 131 deletions(-) create mode 100644 diracx-client/src/diracx/client/patches/pilots/aio.py create mode 100644 diracx-client/src/diracx/client/patches/pilots/common.py create mode 100644 diracx-client/src/diracx/client/patches/pilots/sync.py create mode 100644 diracx-core/src/diracx/core/models/pilot.py delete mode 100644 diracx-db/src/diracx/db/sql/pilot_agents/db.py rename diracx-db/src/diracx/db/sql/{pilot_agents => pilots}/__init__.py (100%) create mode 100644 diracx-db/src/diracx/db/sql/pilots/db.py rename diracx-db/src/diracx/db/sql/{pilot_agents => pilots}/schema.py (94%) delete mode 100644 diracx-db/tests/pilot_agents/test_pilot_agents_db.py rename diracx-db/tests/{pilot_agents => pilots}/__init__.py (100%) create mode 100644 diracx-db/tests/pilots/test_pilot_management.py create mode 100644 diracx-db/tests/pilots/test_query.py create mode 100644 diracx-db/tests/pilots/utils.py create mode 100644 diracx-logic/src/diracx/logic/pilots/__init__.py create mode 100644 diracx-logic/src/diracx/logic/pilots/management.py create mode 100644 diracx-logic/src/diracx/logic/pilots/query.py create mode 100644 diracx-routers/src/diracx/routers/pilots/__init__.py create mode 100644 diracx-routers/src/diracx/routers/pilots/access_policies.py create mode 100644 diracx-routers/src/diracx/routers/pilots/management.py create mode 100644 diracx-routers/src/diracx/routers/pilots/query.py create mode 100644 diracx-routers/tests/pilots/test_pilot_creation.py create mode 100644 diracx-routers/tests/pilots/test_query.py create mode 100644 docs/dev/explanations/pilots.md diff --git a/diracx-client/src/diracx/client/_generated/_client.py b/diracx-client/src/diracx/client/_generated/_client.py index caf48034f..6984754b4 100644 --- a/diracx-client/src/diracx/client/_generated/_client.py +++ b/diracx-client/src/diracx/client/_generated/_client.py @@ -15,7 +15,7 @@ from . import models as _models from ._configuration import DiracConfiguration from ._utils.serialization import Deserializer, Serializer -from .operations import AuthOperations, ConfigOperations, JobsOperations, WellKnownOperations +from .operations import AuthOperations, ConfigOperations, JobsOperations, PilotsOperations, WellKnownOperations class Dirac: # pylint: disable=client-accepts-api-version-keyword @@ -29,6 +29,8 @@ class Dirac: # pylint: disable=client-accepts-api-version-keyword :vartype config: _generated.operations.ConfigOperations :ivar jobs: JobsOperations operations :vartype jobs: _generated.operations.JobsOperations + :ivar pilots: PilotsOperations operations + :vartype pilots: _generated.operations.PilotsOperations :keyword endpoint: Service URL. Required. Default value is "". :paramtype endpoint: str """ @@ -65,6 +67,7 @@ def __init__( # pylint: disable=missing-client-constructor-parameter-credential self.auth = AuthOperations(self._client, self._config, self._serialize, self._deserialize) self.config = ConfigOperations(self._client, self._config, self._serialize, self._deserialize) self.jobs = JobsOperations(self._client, self._config, self._serialize, self._deserialize) + self.pilots = PilotsOperations(self._client, self._config, self._serialize, self._deserialize) def send_request(self, request: HttpRequest, *, stream: bool = False, **kwargs: Any) -> HttpResponse: """Runs the network request through the client's chained policies. diff --git a/diracx-client/src/diracx/client/_generated/aio/_client.py b/diracx-client/src/diracx/client/_generated/aio/_client.py index 79ab383a9..f505a4f10 100644 --- a/diracx-client/src/diracx/client/_generated/aio/_client.py +++ b/diracx-client/src/diracx/client/_generated/aio/_client.py @@ -15,7 +15,7 @@ from .. import models as _models from .._utils.serialization import Deserializer, Serializer from ._configuration import DiracConfiguration -from .operations import AuthOperations, ConfigOperations, JobsOperations, WellKnownOperations +from .operations import AuthOperations, ConfigOperations, JobsOperations, PilotsOperations, WellKnownOperations class Dirac: # pylint: disable=client-accepts-api-version-keyword @@ -29,6 +29,8 @@ class Dirac: # pylint: disable=client-accepts-api-version-keyword :vartype config: _generated.aio.operations.ConfigOperations :ivar jobs: JobsOperations operations :vartype jobs: _generated.aio.operations.JobsOperations + :ivar pilots: PilotsOperations operations + :vartype pilots: _generated.aio.operations.PilotsOperations :keyword endpoint: Service URL. Required. Default value is "". :paramtype endpoint: str """ @@ -65,6 +67,7 @@ def __init__( # pylint: disable=missing-client-constructor-parameter-credential self.auth = AuthOperations(self._client, self._config, self._serialize, self._deserialize) self.config = ConfigOperations(self._client, self._config, self._serialize, self._deserialize) self.jobs = JobsOperations(self._client, self._config, self._serialize, self._deserialize) + self.pilots = PilotsOperations(self._client, self._config, self._serialize, self._deserialize) def send_request( self, request: HttpRequest, *, stream: bool = False, **kwargs: Any diff --git a/diracx-client/src/diracx/client/_generated/aio/operations/__init__.py b/diracx-client/src/diracx/client/_generated/aio/operations/__init__.py index 6be34fb8a..c23cd2b3f 100644 --- a/diracx-client/src/diracx/client/_generated/aio/operations/__init__.py +++ b/diracx-client/src/diracx/client/_generated/aio/operations/__init__.py @@ -14,6 +14,7 @@ from ._operations import AuthOperations # type: ignore from ._operations import ConfigOperations # type: ignore from ._operations import JobsOperations # type: ignore +from ._operations import PilotsOperations # type: ignore from ._patch import __all__ as _patch_all from ._patch import * @@ -24,6 +25,7 @@ "AuthOperations", "ConfigOperations", "JobsOperations", + "PilotsOperations", ] __all__.extend([p for p in _patch_all if p not in __all__]) # pyright: ignore _patch_sdk() diff --git a/diracx-client/src/diracx/client/_generated/aio/operations/_operations.py b/diracx-client/src/diracx/client/_generated/aio/operations/_operations.py index 8aee57b46..0e46aee29 100644 --- a/diracx-client/src/diracx/client/_generated/aio/operations/_operations.py +++ b/diracx-client/src/diracx/client/_generated/aio/operations/_operations.py @@ -51,6 +51,12 @@ build_jobs_summary_request, build_jobs_unassign_bulk_jobs_sandboxes_request, build_jobs_unassign_job_sandboxes_request, + build_pilots_add_pilot_stamps_request, + build_pilots_delete_pilots_request, + build_pilots_get_pilot_jobs_request, + build_pilots_search_request, + build_pilots_summary_request, + build_pilots_update_pilot_fields_request, build_well_known_get_installation_metadata_request, build_well_known_get_jwks_request, build_well_known_get_openid_configuration_request, @@ -2319,3 +2325,583 @@ async def submit_jdl_jobs(self, body: Union[list[str], IO[bytes]], **kwargs: Any return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized # type: ignore + + +class PilotsOperations: + """ + .. warning:: + **DO NOT** instantiate this class directly. + + Instead, you should access the following operations through + :class:`~_generated.aio.Dirac`'s + :attr:`pilots` attribute. + """ + + models = _models + + def __init__(self, *args, **kwargs) -> None: + input_args = list(args) + self._client: AsyncPipelineClient = input_args.pop(0) if input_args else kwargs.pop("client") + self._config: DiracConfiguration = input_args.pop(0) if input_args else kwargs.pop("config") + self._serialize: Serializer = input_args.pop(0) if input_args else kwargs.pop("serializer") + self._deserialize: Deserializer = input_args.pop(0) if input_args else kwargs.pop("deserializer") + + @overload + async def add_pilot_stamps( + self, body: _models.BodyPilotsAddPilotStamps, *, content_type: str = "application/json", **kwargs: Any + ) -> Any: + """Add Pilot Stamps. + + Endpoint where a you can create pilots with their references. + + If a pilot stamp already exists, it will block the insertion. + + :param body: Required. + :type body: ~_generated.models.BodyPilotsAddPilotStamps + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + async def add_pilot_stamps(self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any) -> Any: + """Add Pilot Stamps. + + Endpoint where a you can create pilots with their references. + + If a pilot stamp already exists, it will block the insertion. + + :param body: Required. + :type body: IO[bytes] + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace_async + async def add_pilot_stamps(self, body: Union[_models.BodyPilotsAddPilotStamps, IO[bytes]], **kwargs: Any) -> Any: + """Add Pilot Stamps. + + Endpoint where a you can create pilots with their references. + + If a pilot stamp already exists, it will block the insertion. + + :param body: Is either a BodyPilotsAddPilotStamps type or a IO[bytes] type. Required. + :type body: ~_generated.models.BodyPilotsAddPilotStamps or IO[bytes] + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[Any] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _json = self._serialize.body(body, "BodyPilotsAddPilotStamps") + + _request = build_pilots_add_pilot_stamps_request( + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("object", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore + + @distributed_trace_async + async def delete_pilots( + self, + *, + pilot_stamps: Optional[List[str]] = None, + age_in_days: Optional[int] = None, + delete_only_aborted: bool = False, + **kwargs: Any + ) -> None: + """Delete Pilots. + + Endpoint to delete a pilot. + + Two features: + + + #. Or you provide pilot_stamps, so you can delete pilots by their stamp + #. Or you provide age_in_days, so you can delete pilots that lived more than age_in_days days. + + Note: If you delete a pilot, its logs and its associations with jobs WILL be deleted. + + :keyword pilot_stamps: Stamps of the pilots we want to delete. Default value is None. + :paramtype pilot_stamps: list[str] + :keyword age_in_days: The number of days that define the maximum age of pilots to be + deleted.Pilots older than this age will be considered for deletion. Default value is None. + :paramtype age_in_days: int + :keyword delete_only_aborted: Flag indicating whether to only delete pilots whose status is + 'Aborted'.If set to True, only pilots with the 'Aborted' status will be deleted.It is set by + default as True to avoid any mistake.This flag is only used for deletion by time. Default value + is False. + :paramtype delete_only_aborted: bool + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = kwargs.pop("headers", {}) or {} + _params = kwargs.pop("params", {}) or {} + + cls: ClsType[None] = kwargs.pop("cls", None) + + _request = build_pilots_delete_pilots_request( + pilot_stamps=pilot_stamps, + age_in_days=age_in_days, + delete_only_aborted=delete_only_aborted, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [204]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + if cls: + return cls(pipeline_response, None, {}) # type: ignore + + @overload + async def update_pilot_fields( + self, body: _models.BodyPilotsUpdatePilotFields, *, content_type: str = "application/json", **kwargs: Any + ) -> None: + """Update Pilot Fields. + + Modify a field of a pilot. + + Note: Only the fields in PilotFieldsMapping are mutable, except for the PilotStamp. + + :param body: Required. + :type body: ~_generated.models.BodyPilotsUpdatePilotFields + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + async def update_pilot_fields( + self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any + ) -> None: + """Update Pilot Fields. + + Modify a field of a pilot. + + Note: Only the fields in PilotFieldsMapping are mutable, except for the PilotStamp. + + :param body: Required. + :type body: IO[bytes] + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace_async + async def update_pilot_fields( + self, body: Union[_models.BodyPilotsUpdatePilotFields, IO[bytes]], **kwargs: Any + ) -> None: + """Update Pilot Fields. + + Modify a field of a pilot. + + Note: Only the fields in PilotFieldsMapping are mutable, except for the PilotStamp. + + :param body: Is either a BodyPilotsUpdatePilotFields type or a IO[bytes] type. Required. + :type body: ~_generated.models.BodyPilotsUpdatePilotFields or IO[bytes] + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[None] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _json = self._serialize.body(body, "BodyPilotsUpdatePilotFields") + + _request = build_pilots_update_pilot_fields_request( + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [204]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + if cls: + return cls(pipeline_response, None, {}) # type: ignore + + @distributed_trace_async + async def get_pilot_jobs( + self, *, pilot_stamp: Optional[str] = None, job_id: Optional[int] = None, **kwargs: Any + ) -> List[int]: + """Get Pilot Jobs. + + Endpoint only for admins, to get jobs of a pilot. + + :keyword pilot_stamp: The stamp of the pilot. Default value is None. + :paramtype pilot_stamp: str + :keyword job_id: The ID of the job. Default value is None. + :paramtype job_id: int + :return: list of int + :rtype: list[int] + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = kwargs.pop("headers", {}) or {} + _params = kwargs.pop("params", {}) or {} + + cls: ClsType[List[int]] = kwargs.pop("cls", None) + + _request = build_pilots_get_pilot_jobs_request( + pilot_stamp=pilot_stamp, + job_id=job_id, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("[int]", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore + + @overload + async def search( + self, + body: Optional[_models.SearchParams] = None, + *, + page: int = 1, + per_page: int = 100, + content_type: str = "application/json", + **kwargs: Any + ) -> List[Dict[str, Any]]: + """Search. + + Retrieve information about pilots. + + :param body: Default value is None. + :type body: ~_generated.models.SearchParams + :keyword page: Default value is 1. + :paramtype page: int + :keyword per_page: Default value is 100. + :paramtype per_page: int + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: list of dict mapping str to any + :rtype: list[dict[str, any]] + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + async def search( + self, + body: Optional[IO[bytes]] = None, + *, + page: int = 1, + per_page: int = 100, + content_type: str = "application/json", + **kwargs: Any + ) -> List[Dict[str, Any]]: + """Search. + + Retrieve information about pilots. + + :param body: Default value is None. + :type body: IO[bytes] + :keyword page: Default value is 1. + :paramtype page: int + :keyword per_page: Default value is 100. + :paramtype per_page: int + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: list of dict mapping str to any + :rtype: list[dict[str, any]] + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace_async + async def search( + self, + body: Optional[Union[_models.SearchParams, IO[bytes]]] = None, + *, + page: int = 1, + per_page: int = 100, + **kwargs: Any + ) -> List[Dict[str, Any]]: + """Search. + + Retrieve information about pilots. + + :param body: Is either a SearchParams type or a IO[bytes] type. Default value is None. + :type body: ~_generated.models.SearchParams or IO[bytes] + :keyword page: Default value is 1. + :paramtype page: int + :keyword per_page: Default value is 100. + :paramtype per_page: int + :return: list of dict mapping str to any + :rtype: list[dict[str, any]] + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[List[Dict[str, Any]]] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + if body is not None: + _json = self._serialize.body(body, "SearchParams") + else: + _json = None + + _request = build_pilots_search_request( + page=page, + per_page=per_page, + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200, 206]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + response_headers = {} + if response.status_code == 206: + response_headers["Content-Range"] = self._deserialize("str", response.headers.get("Content-Range")) + + deserialized = self._deserialize("[{object}]", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, response_headers) # type: ignore + + return deserialized # type: ignore + + @overload + async def summary( + self, body: _models.SummaryParams, *, content_type: str = "application/json", **kwargs: Any + ) -> Any: + """Summary. + + Show information suitable for plotting. + + :param body: Required. + :type body: ~_generated.models.SummaryParams + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + async def summary(self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any) -> Any: + """Summary. + + Show information suitable for plotting. + + :param body: Required. + :type body: IO[bytes] + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace_async + async def summary(self, body: Union[_models.SummaryParams, IO[bytes]], **kwargs: Any) -> Any: + """Summary. + + Show information suitable for plotting. + + :param body: Is either a SummaryParams type or a IO[bytes] type. Required. + :type body: ~_generated.models.SummaryParams or IO[bytes] + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[Any] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _json = self._serialize.body(body, "SummaryParams") + + _request = build_pilots_summary_request( + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("object", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore diff --git a/diracx-client/src/diracx/client/_generated/aio/operations/_patch.py b/diracx-client/src/diracx/client/_generated/aio/operations/_patch.py index a408e57d2..0c70ce3e9 100644 --- a/diracx-client/src/diracx/client/_generated/aio/operations/_patch.py +++ b/diracx-client/src/diracx/client/_generated/aio/operations/_patch.py @@ -11,10 +11,12 @@ __all__ = [ "AuthOperations", "JobsOperations", + "PilotsOperations", ] # Add all objects you want publicly available to users at this package level from ....patches.auth.aio import AuthOperations from ....patches.jobs.aio import JobsOperations +from ....patches.pilots.aio import PilotsOperations def patch_sdk(): diff --git a/diracx-client/src/diracx/client/_generated/models/__init__.py b/diracx-client/src/diracx/client/_generated/models/__init__.py index 14b5195d4..8e1dbe20d 100644 --- a/diracx-client/src/diracx/client/_generated/models/__init__.py +++ b/diracx-client/src/diracx/client/_generated/models/__init__.py @@ -16,6 +16,8 @@ BodyAuthRevokeRefreshTokenByRefreshToken, BodyJobsRescheduleJobs, BodyJobsUnassignBulkJobsSandboxes, + BodyPilotsAddPilotStamps, + BodyPilotsUpdatePilotFields, GroupInfo, HTTPValidationError, HeartbeatData, @@ -26,6 +28,7 @@ JobStatusUpdate, Metadata, OpenIDConfiguration, + PilotFieldsMapping, SandboxDownloadResponse, SandboxInfo, SandboxUploadResponse, @@ -48,6 +51,7 @@ from ._enums import ( # type: ignore ChecksumAlgorithm, JobStatus, + PilotStatus, SandboxFormat, SandboxType, ScalarSearchOperator, @@ -63,6 +67,8 @@ "BodyAuthRevokeRefreshTokenByRefreshToken", "BodyJobsRescheduleJobs", "BodyJobsUnassignBulkJobsSandboxes", + "BodyPilotsAddPilotStamps", + "BodyPilotsUpdatePilotFields", "GroupInfo", "HTTPValidationError", "HeartbeatData", @@ -73,6 +79,7 @@ "JobStatusUpdate", "Metadata", "OpenIDConfiguration", + "PilotFieldsMapping", "SandboxDownloadResponse", "SandboxInfo", "SandboxUploadResponse", @@ -92,6 +99,7 @@ "VectorSearchSpec", "ChecksumAlgorithm", "JobStatus", + "PilotStatus", "SandboxFormat", "SandboxType", "ScalarSearchOperator", diff --git a/diracx-client/src/diracx/client/_generated/models/_enums.py b/diracx-client/src/diracx/client/_generated/models/_enums.py index b83473639..849d3252a 100644 --- a/diracx-client/src/diracx/client/_generated/models/_enums.py +++ b/diracx-client/src/diracx/client/_generated/models/_enums.py @@ -34,6 +34,19 @@ class JobStatus(str, Enum, metaclass=CaseInsensitiveEnumMeta): RESCHEDULED = "Rescheduled" +class PilotStatus(str, Enum, metaclass=CaseInsensitiveEnumMeta): + """PilotStatus.""" + + SUBMITTED = "Submitted" + WAITING = "Waiting" + RUNNING = "Running" + DONE = "Done" + FAILED = "Failed" + DELETED = "Deleted" + ABORTED = "Aborted" + UNKNOWN = "Unknown" + + class SandboxFormat(str, Enum, metaclass=CaseInsensitiveEnumMeta): """SandboxFormat.""" diff --git a/diracx-client/src/diracx/client/_generated/models/_models.py b/diracx-client/src/diracx/client/_generated/models/_models.py index 888ec3b8a..f592def1f 100644 --- a/diracx-client/src/diracx/client/_generated/models/_models.py +++ b/diracx-client/src/diracx/client/_generated/models/_models.py @@ -184,6 +184,109 @@ def __init__(self, *, job_ids: list[int], **kwargs: Any) -> None: self.job_ids = job_ids +class BodyPilotsAddPilotStamps(_serialization.Model): + """Body_pilots_add_pilot_stamps. + + All required parameters must be populated in order to send to server. + + :ivar pilot_stamps: List of the pilot stamps we want to add to the db. Required. + :vartype pilot_stamps: list[str] + :ivar vo: Pilot virtual organization. Required. + :vartype vo: str + :ivar grid_type: Grid type of the pilots. + :vartype grid_type: str + :ivar grid_site: Pilots grid site. + :vartype grid_site: str + :ivar destination_site: Pilots destination site. + :vartype destination_site: str + :ivar pilot_references: Association of a pilot reference with a pilot stamp. + :vartype pilot_references: dict[str, str] + :ivar pilot_status: Status of the pilots. Known values are: "Submitted", "Waiting", "Running", + "Done", "Failed", "Deleted", "Aborted", and "Unknown". + :vartype pilot_status: str or ~_generated.models.PilotStatus + """ + + _validation = { + "pilot_stamps": {"required": True}, + "vo": {"required": True}, + } + + _attribute_map = { + "pilot_stamps": {"key": "pilot_stamps", "type": "[str]"}, + "vo": {"key": "vo", "type": "str"}, + "grid_type": {"key": "grid_type", "type": "str"}, + "grid_site": {"key": "grid_site", "type": "str"}, + "destination_site": {"key": "destination_site", "type": "str"}, + "pilot_references": {"key": "pilot_references", "type": "{str}"}, + "pilot_status": {"key": "pilot_status", "type": "str"}, + } + + def __init__( + self, + *, + pilot_stamps: List[str], + vo: str, + grid_type: str = "Dirac", + grid_site: str = "Unknown", + destination_site: str = "NotAssigned", + pilot_references: Optional[Dict[str, str]] = None, + pilot_status: Optional[Union[str, "_models.PilotStatus"]] = None, + **kwargs: Any + ) -> None: + """ + :keyword pilot_stamps: List of the pilot stamps we want to add to the db. Required. + :paramtype pilot_stamps: list[str] + :keyword vo: Pilot virtual organization. Required. + :paramtype vo: str + :keyword grid_type: Grid type of the pilots. + :paramtype grid_type: str + :keyword grid_site: Pilots grid site. + :paramtype grid_site: str + :keyword destination_site: Pilots destination site. + :paramtype destination_site: str + :keyword pilot_references: Association of a pilot reference with a pilot stamp. + :paramtype pilot_references: dict[str, str] + :keyword pilot_status: Status of the pilots. Known values are: "Submitted", "Waiting", + "Running", "Done", "Failed", "Deleted", "Aborted", and "Unknown". + :paramtype pilot_status: str or ~_generated.models.PilotStatus + """ + super().__init__(**kwargs) + self.pilot_stamps = pilot_stamps + self.vo = vo + self.grid_type = grid_type + self.grid_site = grid_site + self.destination_site = destination_site + self.pilot_references = pilot_references + self.pilot_status = pilot_status + + +class BodyPilotsUpdatePilotFields(_serialization.Model): + """Body_pilots_update_pilot_fields. + + All required parameters must be populated in order to send to server. + + :ivar pilot_stamps_to_fields_mapping: (pilot_stamp, pilot_fields) mapping to change. Required. + :vartype pilot_stamps_to_fields_mapping: list[~_generated.models.PilotFieldsMapping] + """ + + _validation = { + "pilot_stamps_to_fields_mapping": {"required": True}, + } + + _attribute_map = { + "pilot_stamps_to_fields_mapping": {"key": "pilot_stamps_to_fields_mapping", "type": "[PilotFieldsMapping]"}, + } + + def __init__(self, *, pilot_stamps_to_fields_mapping: List["_models.PilotFieldsMapping"], **kwargs: Any) -> None: + """ + :keyword pilot_stamps_to_fields_mapping: (pilot_stamp, pilot_fields) mapping to change. + Required. + :paramtype pilot_stamps_to_fields_mapping: list[~_generated.models.PilotFieldsMapping] + """ + super().__init__(**kwargs) + self.pilot_stamps_to_fields_mapping = pilot_stamps_to_fields_mapping + + class GroupInfo(_serialization.Model): """GroupInfo. @@ -929,6 +1032,102 @@ def __init__( self.code_challenge_methods_supported = code_challenge_methods_supported +class PilotFieldsMapping(_serialization.Model): + """All the fields that a user can modify on a Pilot (except PilotStamp). + + All required parameters must be populated in order to send to server. + + :ivar pilot_stamp: Pilotstamp. Required. + :vartype pilot_stamp: str + :ivar status_reason: Statusreason. + :vartype status_reason: str + :ivar status: PilotStatus. Known values are: "Submitted", "Waiting", "Running", "Done", + "Failed", "Deleted", "Aborted", and "Unknown". + :vartype status: str or ~_generated.models.PilotStatus + :ivar bench_mark: Benchmark. + :vartype bench_mark: float + :ivar destination_site: Destinationsite. + :vartype destination_site: str + :ivar queue: Queue. + :vartype queue: str + :ivar grid_site: Gridsite. + :vartype grid_site: str + :ivar grid_type: Gridtype. + :vartype grid_type: str + :ivar accounting_sent: Accountingsent. + :vartype accounting_sent: bool + :ivar current_job_id: Currentjobid. + :vartype current_job_id: int + """ + + _validation = { + "pilot_stamp": {"required": True}, + } + + _attribute_map = { + "pilot_stamp": {"key": "PilotStamp", "type": "str"}, + "status_reason": {"key": "StatusReason", "type": "str"}, + "status": {"key": "Status", "type": "str"}, + "bench_mark": {"key": "BenchMark", "type": "float"}, + "destination_site": {"key": "DestinationSite", "type": "str"}, + "queue": {"key": "Queue", "type": "str"}, + "grid_site": {"key": "GridSite", "type": "str"}, + "grid_type": {"key": "GridType", "type": "str"}, + "accounting_sent": {"key": "AccountingSent", "type": "bool"}, + "current_job_id": {"key": "CurrentJobID", "type": "int"}, + } + + def __init__( + self, + *, + pilot_stamp: str, + status_reason: Optional[str] = None, + status: Optional[Union[str, "_models.PilotStatus"]] = None, + bench_mark: Optional[float] = None, + destination_site: Optional[str] = None, + queue: Optional[str] = None, + grid_site: Optional[str] = None, + grid_type: Optional[str] = None, + accounting_sent: Optional[bool] = None, + current_job_id: Optional[int] = None, + **kwargs: Any + ) -> None: + """ + :keyword pilot_stamp: Pilotstamp. Required. + :paramtype pilot_stamp: str + :keyword status_reason: Statusreason. + :paramtype status_reason: str + :keyword status: PilotStatus. Known values are: "Submitted", "Waiting", "Running", "Done", + "Failed", "Deleted", "Aborted", and "Unknown". + :paramtype status: str or ~_generated.models.PilotStatus + :keyword bench_mark: Benchmark. + :paramtype bench_mark: float + :keyword destination_site: Destinationsite. + :paramtype destination_site: str + :keyword queue: Queue. + :paramtype queue: str + :keyword grid_site: Gridsite. + :paramtype grid_site: str + :keyword grid_type: Gridtype. + :paramtype grid_type: str + :keyword accounting_sent: Accountingsent. + :paramtype accounting_sent: bool + :keyword current_job_id: Currentjobid. + :paramtype current_job_id: int + """ + super().__init__(**kwargs) + self.pilot_stamp = pilot_stamp + self.status_reason = status_reason + self.status = status + self.bench_mark = bench_mark + self.destination_site = destination_site + self.queue = queue + self.grid_site = grid_site + self.grid_type = grid_type + self.accounting_sent = accounting_sent + self.current_job_id = current_job_id + + class SandboxDownloadResponse(_serialization.Model): """SandboxDownloadResponse. diff --git a/diracx-client/src/diracx/client/_generated/operations/__init__.py b/diracx-client/src/diracx/client/_generated/operations/__init__.py index 6be34fb8a..c23cd2b3f 100644 --- a/diracx-client/src/diracx/client/_generated/operations/__init__.py +++ b/diracx-client/src/diracx/client/_generated/operations/__init__.py @@ -14,6 +14,7 @@ from ._operations import AuthOperations # type: ignore from ._operations import ConfigOperations # type: ignore from ._operations import JobsOperations # type: ignore +from ._operations import PilotsOperations # type: ignore from ._patch import __all__ as _patch_all from ._patch import * @@ -24,6 +25,7 @@ "AuthOperations", "ConfigOperations", "JobsOperations", + "PilotsOperations", ] __all__.extend([p for p in _patch_all if p not in __all__]) # pyright: ignore _patch_sdk() diff --git a/diracx-client/src/diracx/client/_generated/operations/_operations.py b/diracx-client/src/diracx/client/_generated/operations/_operations.py index 11ffdcff7..954a82b0a 100644 --- a/diracx-client/src/diracx/client/_generated/operations/_operations.py +++ b/diracx-client/src/diracx/client/_generated/operations/_operations.py @@ -565,6 +565,124 @@ def build_jobs_submit_jdl_jobs_request(**kwargs: Any) -> HttpRequest: return HttpRequest(method="POST", url=_url, headers=_headers, **kwargs) +def build_pilots_add_pilot_stamps_request(**kwargs: Any) -> HttpRequest: + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + accept = _headers.pop("Accept", "application/json") + + # Construct URL + _url = "/api/pilots/" + + # Construct headers + if content_type is not None: + _headers["Content-Type"] = _SERIALIZER.header("content_type", content_type, "str") + _headers["Accept"] = _SERIALIZER.header("accept", accept, "str") + + return HttpRequest(method="POST", url=_url, headers=_headers, **kwargs) + + +def build_pilots_delete_pilots_request( + *, + pilot_stamps: Optional[List[str]] = None, + age_in_days: Optional[int] = None, + delete_only_aborted: bool = False, + **kwargs: Any +) -> HttpRequest: + _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) + + # Construct URL + _url = "/api/pilots/" + + # Construct parameters + if pilot_stamps is not None: + _params["pilot_stamps"] = _SERIALIZER.query("pilot_stamps", pilot_stamps, "[str]") + if age_in_days is not None: + _params["age_in_days"] = _SERIALIZER.query("age_in_days", age_in_days, "int") + if delete_only_aborted is not None: + _params["delete_only_aborted"] = _SERIALIZER.query("delete_only_aborted", delete_only_aborted, "bool") + + return HttpRequest(method="DELETE", url=_url, params=_params, **kwargs) + + +def build_pilots_update_pilot_fields_request(**kwargs: Any) -> HttpRequest: + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + # Construct URL + _url = "/api/pilots/metadata" + + # Construct headers + if content_type is not None: + _headers["Content-Type"] = _SERIALIZER.header("content_type", content_type, "str") + + return HttpRequest(method="PATCH", url=_url, headers=_headers, **kwargs) + + +def build_pilots_get_pilot_jobs_request( + *, pilot_stamp: Optional[str] = None, job_id: Optional[int] = None, **kwargs: Any +) -> HttpRequest: + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) + + accept = _headers.pop("Accept", "application/json") + + # Construct URL + _url = "/api/pilots/jobs" + + # Construct parameters + if pilot_stamp is not None: + _params["pilot_stamp"] = _SERIALIZER.query("pilot_stamp", pilot_stamp, "str") + if job_id is not None: + _params["job_id"] = _SERIALIZER.query("job_id", job_id, "int") + + # Construct headers + _headers["Accept"] = _SERIALIZER.header("accept", accept, "str") + + return HttpRequest(method="GET", url=_url, params=_params, headers=_headers, **kwargs) + + +def build_pilots_search_request(*, page: int = 1, per_page: int = 100, **kwargs: Any) -> HttpRequest: + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + accept = _headers.pop("Accept", "application/json") + + # Construct URL + _url = "/api/pilots/search" + + # Construct parameters + if page is not None: + _params["page"] = _SERIALIZER.query("page", page, "int") + if per_page is not None: + _params["per_page"] = _SERIALIZER.query("per_page", per_page, "int") + + # Construct headers + if content_type is not None: + _headers["Content-Type"] = _SERIALIZER.header("content_type", content_type, "str") + _headers["Accept"] = _SERIALIZER.header("accept", accept, "str") + + return HttpRequest(method="POST", url=_url, params=_params, headers=_headers, **kwargs) + + +def build_pilots_summary_request(**kwargs: Any) -> HttpRequest: + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + accept = _headers.pop("Accept", "application/json") + + # Construct URL + _url = "/api/pilots/summary" + + # Construct headers + if content_type is not None: + _headers["Content-Type"] = _SERIALIZER.header("content_type", content_type, "str") + _headers["Accept"] = _SERIALIZER.header("accept", accept, "str") + + return HttpRequest(method="POST", url=_url, headers=_headers, **kwargs) + + class WellKnownOperations: """ .. warning:: @@ -2818,3 +2936,579 @@ def submit_jdl_jobs(self, body: Union[list[str], IO[bytes]], **kwargs: Any) -> l return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized # type: ignore + + +class PilotsOperations: + """ + .. warning:: + **DO NOT** instantiate this class directly. + + Instead, you should access the following operations through + :class:`~_generated.Dirac`'s + :attr:`pilots` attribute. + """ + + models = _models + + def __init__(self, *args, **kwargs) -> None: + input_args = list(args) + self._client: PipelineClient = input_args.pop(0) if input_args else kwargs.pop("client") + self._config: DiracConfiguration = input_args.pop(0) if input_args else kwargs.pop("config") + self._serialize: Serializer = input_args.pop(0) if input_args else kwargs.pop("serializer") + self._deserialize: Deserializer = input_args.pop(0) if input_args else kwargs.pop("deserializer") + + @overload + def add_pilot_stamps( + self, body: _models.BodyPilotsAddPilotStamps, *, content_type: str = "application/json", **kwargs: Any + ) -> Any: + """Add Pilot Stamps. + + Endpoint where a you can create pilots with their references. + + If a pilot stamp already exists, it will block the insertion. + + :param body: Required. + :type body: ~_generated.models.BodyPilotsAddPilotStamps + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + def add_pilot_stamps(self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any) -> Any: + """Add Pilot Stamps. + + Endpoint where a you can create pilots with their references. + + If a pilot stamp already exists, it will block the insertion. + + :param body: Required. + :type body: IO[bytes] + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace + def add_pilot_stamps(self, body: Union[_models.BodyPilotsAddPilotStamps, IO[bytes]], **kwargs: Any) -> Any: + """Add Pilot Stamps. + + Endpoint where a you can create pilots with their references. + + If a pilot stamp already exists, it will block the insertion. + + :param body: Is either a BodyPilotsAddPilotStamps type or a IO[bytes] type. Required. + :type body: ~_generated.models.BodyPilotsAddPilotStamps or IO[bytes] + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[Any] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _json = self._serialize.body(body, "BodyPilotsAddPilotStamps") + + _request = build_pilots_add_pilot_stamps_request( + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("object", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore + + @distributed_trace + def delete_pilots( # pylint: disable=inconsistent-return-statements + self, + *, + pilot_stamps: Optional[List[str]] = None, + age_in_days: Optional[int] = None, + delete_only_aborted: bool = False, + **kwargs: Any + ) -> None: + """Delete Pilots. + + Endpoint to delete a pilot. + + Two features: + + + #. Or you provide pilot_stamps, so you can delete pilots by their stamp + #. Or you provide age_in_days, so you can delete pilots that lived more than age_in_days days. + + Note: If you delete a pilot, its logs and its associations with jobs WILL be deleted. + + :keyword pilot_stamps: Stamps of the pilots we want to delete. Default value is None. + :paramtype pilot_stamps: list[str] + :keyword age_in_days: The number of days that define the maximum age of pilots to be + deleted.Pilots older than this age will be considered for deletion. Default value is None. + :paramtype age_in_days: int + :keyword delete_only_aborted: Flag indicating whether to only delete pilots whose status is + 'Aborted'.If set to True, only pilots with the 'Aborted' status will be deleted.It is set by + default as True to avoid any mistake.This flag is only used for deletion by time. Default value + is False. + :paramtype delete_only_aborted: bool + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = kwargs.pop("headers", {}) or {} + _params = kwargs.pop("params", {}) or {} + + cls: ClsType[None] = kwargs.pop("cls", None) + + _request = build_pilots_delete_pilots_request( + pilot_stamps=pilot_stamps, + age_in_days=age_in_days, + delete_only_aborted=delete_only_aborted, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [204]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + if cls: + return cls(pipeline_response, None, {}) # type: ignore + + @overload + def update_pilot_fields( + self, body: _models.BodyPilotsUpdatePilotFields, *, content_type: str = "application/json", **kwargs: Any + ) -> None: + """Update Pilot Fields. + + Modify a field of a pilot. + + Note: Only the fields in PilotFieldsMapping are mutable, except for the PilotStamp. + + :param body: Required. + :type body: ~_generated.models.BodyPilotsUpdatePilotFields + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + def update_pilot_fields(self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any) -> None: + """Update Pilot Fields. + + Modify a field of a pilot. + + Note: Only the fields in PilotFieldsMapping are mutable, except for the PilotStamp. + + :param body: Required. + :type body: IO[bytes] + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace + def update_pilot_fields( # pylint: disable=inconsistent-return-statements + self, body: Union[_models.BodyPilotsUpdatePilotFields, IO[bytes]], **kwargs: Any + ) -> None: + """Update Pilot Fields. + + Modify a field of a pilot. + + Note: Only the fields in PilotFieldsMapping are mutable, except for the PilotStamp. + + :param body: Is either a BodyPilotsUpdatePilotFields type or a IO[bytes] type. Required. + :type body: ~_generated.models.BodyPilotsUpdatePilotFields or IO[bytes] + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[None] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _json = self._serialize.body(body, "BodyPilotsUpdatePilotFields") + + _request = build_pilots_update_pilot_fields_request( + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [204]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + if cls: + return cls(pipeline_response, None, {}) # type: ignore + + @distributed_trace + def get_pilot_jobs( + self, *, pilot_stamp: Optional[str] = None, job_id: Optional[int] = None, **kwargs: Any + ) -> List[int]: + """Get Pilot Jobs. + + Endpoint only for admins, to get jobs of a pilot. + + :keyword pilot_stamp: The stamp of the pilot. Default value is None. + :paramtype pilot_stamp: str + :keyword job_id: The ID of the job. Default value is None. + :paramtype job_id: int + :return: list of int + :rtype: list[int] + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = kwargs.pop("headers", {}) or {} + _params = kwargs.pop("params", {}) or {} + + cls: ClsType[List[int]] = kwargs.pop("cls", None) + + _request = build_pilots_get_pilot_jobs_request( + pilot_stamp=pilot_stamp, + job_id=job_id, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("[int]", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore + + @overload + def search( + self, + body: Optional[_models.SearchParams] = None, + *, + page: int = 1, + per_page: int = 100, + content_type: str = "application/json", + **kwargs: Any + ) -> List[Dict[str, Any]]: + """Search. + + Retrieve information about pilots. + + :param body: Default value is None. + :type body: ~_generated.models.SearchParams + :keyword page: Default value is 1. + :paramtype page: int + :keyword per_page: Default value is 100. + :paramtype per_page: int + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: list of dict mapping str to any + :rtype: list[dict[str, any]] + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + def search( + self, + body: Optional[IO[bytes]] = None, + *, + page: int = 1, + per_page: int = 100, + content_type: str = "application/json", + **kwargs: Any + ) -> List[Dict[str, Any]]: + """Search. + + Retrieve information about pilots. + + :param body: Default value is None. + :type body: IO[bytes] + :keyword page: Default value is 1. + :paramtype page: int + :keyword per_page: Default value is 100. + :paramtype per_page: int + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: list of dict mapping str to any + :rtype: list[dict[str, any]] + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace + def search( + self, + body: Optional[Union[_models.SearchParams, IO[bytes]]] = None, + *, + page: int = 1, + per_page: int = 100, + **kwargs: Any + ) -> List[Dict[str, Any]]: + """Search. + + Retrieve information about pilots. + + :param body: Is either a SearchParams type or a IO[bytes] type. Default value is None. + :type body: ~_generated.models.SearchParams or IO[bytes] + :keyword page: Default value is 1. + :paramtype page: int + :keyword per_page: Default value is 100. + :paramtype per_page: int + :return: list of dict mapping str to any + :rtype: list[dict[str, any]] + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[List[Dict[str, Any]]] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + if body is not None: + _json = self._serialize.body(body, "SearchParams") + else: + _json = None + + _request = build_pilots_search_request( + page=page, + per_page=per_page, + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200, 206]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + response_headers = {} + if response.status_code == 206: + response_headers["Content-Range"] = self._deserialize("str", response.headers.get("Content-Range")) + + deserialized = self._deserialize("[{object}]", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, response_headers) # type: ignore + + return deserialized # type: ignore + + @overload + def summary(self, body: _models.SummaryParams, *, content_type: str = "application/json", **kwargs: Any) -> Any: + """Summary. + + Show information suitable for plotting. + + :param body: Required. + :type body: ~_generated.models.SummaryParams + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + def summary(self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any) -> Any: + """Summary. + + Show information suitable for plotting. + + :param body: Required. + :type body: IO[bytes] + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace + def summary(self, body: Union[_models.SummaryParams, IO[bytes]], **kwargs: Any) -> Any: + """Summary. + + Show information suitable for plotting. + + :param body: Is either a SummaryParams type or a IO[bytes] type. Required. + :type body: ~_generated.models.SummaryParams or IO[bytes] + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[Any] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _json = self._serialize.body(body, "SummaryParams") + + _request = build_pilots_summary_request( + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("object", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore diff --git a/diracx-client/src/diracx/client/_generated/operations/_patch.py b/diracx-client/src/diracx/client/_generated/operations/_patch.py index b7b8c67fa..b14e98b84 100644 --- a/diracx-client/src/diracx/client/_generated/operations/_patch.py +++ b/diracx-client/src/diracx/client/_generated/operations/_patch.py @@ -11,10 +11,12 @@ __all__ = [ "AuthOperations", "JobsOperations", + "PilotsOperations", ] # Add all objects you want publicly available to users at this package level from ...patches.auth.sync import AuthOperations from ...patches.jobs.sync import JobsOperations +from ...patches.pilots.sync import PilotsOperations def patch_sdk(): diff --git a/diracx-client/src/diracx/client/patches/pilots/aio.py b/diracx-client/src/diracx/client/patches/pilots/aio.py new file mode 100644 index 000000000..ac533a67c --- /dev/null +++ b/diracx-client/src/diracx/client/patches/pilots/aio.py @@ -0,0 +1,53 @@ +"""Patches for the autorest-generated pilots client. + +This file can be used to customize the generated code for the pilots client. +When adding new classes to this file, make sure to also add them to the +__all__ list in the corresponding file in the patches directory. +""" + +from __future__ import annotations + +__all__ = [ + "PilotsOperations", +] + +from typing import Any, Unpack + +from azure.core.tracing.decorator_async import distributed_trace_async + +from ..._generated.aio.operations._operations import PilotsOperations as _PilotsOperations +from .common import ( + make_search_body, + make_summary_body, + make_add_pilot_stamps_body, + make_update_pilot_fields_body, + SearchKwargs, + SummaryKwargs, + AddPilotStampsKwargs, + UpdatePilotFieldsKwargs +) + +# We're intentionally ignoring overrides here because we want to change the interface. +# mypy: disable-error-code=override + + +class PilotsOperations(_PilotsOperations): + @distributed_trace_async + async def search(self, **kwargs: Unpack[SearchKwargs]) -> list[dict[str, Any]]: + """TODO""" + return await super().search(**make_search_body(**kwargs)) + + @distributed_trace_async + async def summary(self, **kwargs: Unpack[SummaryKwargs]) -> list[dict[str, Any]]: + """TODO""" + return await super().summary(**make_summary_body(**kwargs)) + + @distributed_trace_async + async def add_pilot_stamps(self, **kwargs: Unpack[AddPilotStampsKwargs]) -> None: + """TODO""" + return await super().add_pilot_stamps(**make_add_pilot_stamps_body(**kwargs)) + + @distributed_trace_async + async def update_pilot_fields(self, **kwargs: Unpack[UpdatePilotFieldsKwargs]) -> None: + """TODO""" + return await super().update_pilot_fields(**make_update_pilot_fields_body(**kwargs)) diff --git a/diracx-client/src/diracx/client/patches/pilots/common.py b/diracx-client/src/diracx/client/patches/pilots/common.py new file mode 100644 index 000000000..fd54786f2 --- /dev/null +++ b/diracx-client/src/diracx/client/patches/pilots/common.py @@ -0,0 +1,147 @@ +"""Utilities which are common to the sync and async pilots operator patches.""" + +from __future__ import annotations + +__all__ = [ + "make_search_body", + "SearchKwargs", + "make_summary_body", + "SummaryKwargs", + "AddPilotStampsKwargs", + "make_add_pilot_stamps_body", + "UpdatePilotFieldsKwargs", + "make_update_pilot_fields_body" +] + +import json +from io import BytesIO +from typing import Any, IO, TypedDict, Unpack, cast, Literal + +from diracx.core.models.pilot import PilotFieldsMapping, PilotStatus +from diracx.core.models.search import SearchSpec + + +class ResponseExtra(TypedDict, total=False): + content_type: str + headers: dict[str, str] + params: dict[str, str] + cls: Any + + +# ------------------ Search ------------------ +class SearchBody(TypedDict, total=False): + parameters: list[str] | None + search: list[SearchSpec] | None + sort: list[str] | None + + +class SearchExtra(ResponseExtra, total=False): + page: int + per_page: int + + +class SearchKwargs(SearchBody, SearchExtra): ... + + +class UnderlyingSearchArgs(ResponseExtra, total=False): + # FIXME: The autorest-generated has a bug that it expected IO[bytes] despite + # the code being generated to support IO[bytes] | bytes. + body: IO[bytes] + + +def make_search_body(**kwargs: Unpack[SearchKwargs]) -> UnderlyingSearchArgs: + body: SearchBody = {} + for key in SearchBody.__optional_keys__: + if key not in kwargs: + continue + key = cast(Literal["parameters", "search", "sort"], key) + value = kwargs.pop(key) + if value is not None: + body[key] = value + result: UnderlyingSearchArgs = {"body": BytesIO(json.dumps(body).encode("utf-8"))} + result.update(cast(SearchExtra, kwargs)) + return result + +# ------------------ Summary ------------------ + +class SummaryBody(TypedDict, total=False): + grouping: list[str] + search: list[str] + + +class SummaryKwargs(SummaryBody, ResponseExtra): ... + + +class UnderlyingSummaryArgs(ResponseExtra, total=False): + # FIXME: The autorest-generated has a bug that it expected IO[bytes] despite + # the code being generated to support IO[bytes] | bytes. + body: IO[bytes] + + +def make_summary_body(**kwargs: Unpack[SummaryKwargs]) -> UnderlyingSummaryArgs: + body: SummaryBody = {} + for key in SummaryBody.__optional_keys__: + if key not in kwargs: + continue + key = cast(Literal["grouping", "search"], key) + value = kwargs.pop(key) + if value is not None: + body[key] = value + result: UnderlyingSummaryArgs = {"body": BytesIO(json.dumps(body).encode("utf-8"))} + result.update(cast(ResponseExtra, kwargs)) + return result + +# ------------------ AddPilotStamps ------------------ + +class AddPilotStampsBody(TypedDict, total=False): + pilot_stamps: list[str] + grid_type: str + grid_site: str + pilot_references: dict[str, str] + pilot_status: PilotStatus + vo: str + +class AddPilotStampsKwargs(AddPilotStampsBody, ResponseExtra): ... + +class UnderlyingAddPilotStampsArgs(ResponseExtra, total=False): + # FIXME: The autorest-generated has a bug that it expected IO[bytes] despite + # the code being generated to support IO[bytes] | bytes. + body: IO[bytes] + +def make_add_pilot_stamps_body(**kwargs: Unpack[AddPilotStampsKwargs]) -> UnderlyingAddPilotStampsArgs: + body: AddPilotStampsBody = {} + for key in AddPilotStampsBody.__optional_keys__: + if key not in kwargs: + continue + key = cast(Literal["pilot_stamps", "grid_type", "grid_site", "pilot_references", "pilot_status", "vo"], key) + value = kwargs.pop(key) + if value is not None: + body[key] = value + result: UnderlyingAddPilotStampsArgs = {"body": BytesIO(json.dumps(body).encode("utf-8"))} + result.update(cast(ResponseExtra, kwargs)) + return result + +# ------------------ UpdatePilotFields ------------------ + +class UpdatePilotFieldsBody(TypedDict, total=False): + pilot_stamps_to_fields_mapping: list[PilotFieldsMapping] + +class UpdatePilotFieldsKwargs(UpdatePilotFieldsBody, ResponseExtra): ... + +class UnderlyingUpdatePilotFields(ResponseExtra, total=False): + # FIXME: The autorest-generated has a bug that it expected IO[bytes] despite + # the code being generated to support IO[bytes] | bytes. + body: IO[bytes] + +def make_update_pilot_fields_body(**kwargs: Unpack[UpdatePilotFieldsKwargs]) -> UnderlyingUpdatePilotFields: + body: UpdatePilotFieldsBody = {} + for key in UpdatePilotFieldsBody.__optional_keys__: + if key not in kwargs: + continue + key = cast(Literal["pilot_stamps_to_fields_mapping"], key) + value = kwargs.pop(key) + if value is not None: + body[key] = value + result: UnderlyingUpdatePilotFields = {"body": BytesIO(json.dumps(body).encode("utf-8"))} + result.update(cast(ResponseExtra, kwargs)) + return result diff --git a/diracx-client/src/diracx/client/patches/pilots/sync.py b/diracx-client/src/diracx/client/patches/pilots/sync.py new file mode 100644 index 000000000..744cee161 --- /dev/null +++ b/diracx-client/src/diracx/client/patches/pilots/sync.py @@ -0,0 +1,53 @@ +"""Patches for the autorest-generated pilots client. + +This file can be used to customize the generated code for the pilots client. +When adding new classes to this file, make sure to also add them to the +__all__ list in the corresponding file in the patches directory. +""" + +from __future__ import annotations + +__all__ = [ + "PilotsOperations", +] + +from typing import Any, Unpack + +from azure.core.tracing.decorator import distributed_trace + +from ..._generated.operations._operations import PilotsOperations as _PilotsOperations +from .common import ( + make_search_body, + make_summary_body, + make_add_pilot_stamps_body, + make_update_pilot_fields_body, + SearchKwargs, + SummaryKwargs, + AddPilotStampsKwargs, + UpdatePilotFieldsKwargs +) + +# We're intentionally ignoring overrides here because we want to change the interface. +# mypy: disable-error-code=override + + +class PilotsOperations(_PilotsOperations): + @distributed_trace + def search(self, **kwargs: Unpack[SearchKwargs]) -> list[dict[str, Any]]: + """TODO""" + return super().search(**make_search_body(**kwargs)) + + @distributed_trace + def summary(self, **kwargs: Unpack[SummaryKwargs]) -> list[dict[str, Any]]: + """TODO""" + return super().summary(**make_summary_body(**kwargs)) + + @distributed_trace + def add_pilot_stamps(self, **kwargs: Unpack[AddPilotStampsKwargs]) -> None: + """TODO""" + return super().add_pilot_stamps(**make_add_pilot_stamps_body(**kwargs)) + + @distributed_trace + def update_pilot_fields(self, **kwargs: Unpack[UpdatePilotFieldsKwargs]) -> None: + """TODO""" + return super().update_pilot_fields(**make_update_pilot_fields_body(**kwargs)) diff --git a/diracx-core/src/diracx/core/exceptions.py b/diracx-core/src/diracx/core/exceptions.py index 54d7c240d..19d8d5a41 100644 --- a/diracx-core/src/diracx/core/exceptions.py +++ b/diracx-core/src/diracx/core/exceptions.py @@ -15,6 +15,7 @@ class DiracError(RuntimeError): def __init__(self, detail: str = "Unknown"): self.detail = detail + super().__init__(detail) class AuthorizationError(DiracError): ... @@ -49,19 +50,19 @@ class InvalidQueryError(DiracError): class TokenNotFoundError(DiracError): - def __init__(self, jti: str, detail: str | None = None): + def __init__(self, jti: str, detail: str = ""): self.jti: str = jti super().__init__(f"Token {jti} not found" + (f" ({detail})" if detail else "")) class JobNotFoundError(DiracError): - def __init__(self, job_id: int, detail: str | None = None): + def __init__(self, job_id: int, detail: str = ""): self.job_id: int = job_id super().__init__(f"Job {job_id} not found" + (f" ({detail})" if detail else "")) class SandboxNotFoundError(DiracError): - def __init__(self, pfn: str, se_name: str, detail: str | None = None): + def __init__(self, pfn: str, se_name: str, detail: str = ""): self.pfn: str = pfn self.se_name: str = se_name super().__init__( @@ -71,7 +72,7 @@ def __init__(self, pfn: str, se_name: str, detail: str | None = None): class SandboxAlreadyAssignedError(DiracError): - def __init__(self, pfn: str, se_name: str, detail: str | None = None): + def __init__(self, pfn: str, se_name: str, detail: str = ""): self.pfn: str = pfn self.se_name: str = se_name super().__init__( @@ -81,7 +82,7 @@ def __init__(self, pfn: str, se_name: str, detail: str | None = None): class SandboxAlreadyInsertedError(DiracError): - def __init__(self, pfn: str, se_name: str, detail: str | None = None): + def __init__(self, pfn: str, se_name: str, detail: str = ""): self.pfn: str = pfn self.se_name: str = se_name super().__init__( @@ -91,7 +92,7 @@ def __init__(self, pfn: str, se_name: str, detail: str | None = None): class JobError(DiracError): - def __init__(self, job_id, detail: str | None = None): + def __init__(self, job_id, detail: str = ""): self.job_id: int = job_id super().__init__( f"Error concerning job {job_id}" + (f" ({detail})" if detail else "") @@ -100,3 +101,15 @@ def __init__(self, job_id, detail: str | None = None): class NotReadyError(DiracError): """Tried to access a value which is asynchronously loaded but not yet available.""" + + +class PilotNotFoundError(DiracError): + """At least one pilot is not found.""" + + +class PilotAlreadyExistsError(DiracError): + """At least one pilot already exists, we avoid collitions.""" + + +class PilotAlreadyAssociatedWithJobError(DiracError): + """We can't associate a pilot with the same job twice.""" diff --git a/diracx-core/src/diracx/core/models/job.py b/diracx-core/src/diracx/core/models/job.py index ec098c2c6..94e6452b9 100644 --- a/diracx-core/src/diracx/core/models/job.py +++ b/diracx-core/src/diracx/core/models/job.py @@ -6,7 +6,7 @@ from __future__ import annotations from enum import StrEnum -from typing import Literal +from typing import Literal, Optional from pydantic import BaseModel, Field, field_validator diff --git a/diracx-core/src/diracx/core/models/pilot.py b/diracx-core/src/diracx/core/models/pilot.py new file mode 100644 index 000000000..7abba1378 --- /dev/null +++ b/diracx-core/src/diracx/core/models/pilot.py @@ -0,0 +1,33 @@ +"""Pilot-related models shared between client, logic, and services.""" + +from __future__ import annotations + +from enum import StrEnum + +from pydantic import BaseModel + + +class PilotStatus(StrEnum): + SUBMITTED = "Submitted" + WAITING = "Waiting" + RUNNING = "Running" + DONE = "Done" + FAILED = "Failed" + DELETED = "Deleted" + ABORTED = "Aborted" + UNKNOWN = "Unknown" + + +class PilotFieldsMapping(BaseModel, extra="forbid"): + """All the fields that a user can modify on a Pilot (except PilotStamp).""" + + PilotStamp: str + StatusReason: str | None = None + Status: PilotStatus | None = None + BenchMark: float | None = None + DestinationSite: str | None = None + Queue: str | None = None + GridSite: str | None = None + GridType: str | None = None + AccountingSent: bool | None = None + CurrentJobID: int | None = None diff --git a/diracx-db/src/diracx/db/sql/__init__.py b/diracx-db/src/diracx/db/sql/__init__.py index 3be3af8a3..e2f141ad5 100644 --- a/diracx-db/src/diracx/db/sql/__init__.py +++ b/diracx-db/src/diracx/db/sql/__init__.py @@ -12,6 +12,6 @@ from .auth.db import AuthDB from .job.db import JobDB from .job_logging.db import JobLoggingDB -from .pilot_agents.db import PilotAgentsDB +from .pilots.db import PilotAgentsDB from .sandbox_metadata.db import SandboxMetadataDB from .task_queue.db import TaskQueueDB diff --git a/diracx-db/src/diracx/db/sql/dummy/db.py b/diracx-db/src/diracx/db/sql/dummy/db.py index 76e8db07b..966b6381e 100644 --- a/diracx-db/src/diracx/db/sql/dummy/db.py +++ b/diracx-db/src/diracx/db/sql/dummy/db.py @@ -3,6 +3,7 @@ from sqlalchemy import insert from uuid_utils import UUID +from diracx.core.models import SearchSpec from diracx.db.sql.utils import BaseSQLDB from .schema import Base as DummyDBBase @@ -20,8 +21,11 @@ class DummyDB(BaseSQLDB): # This needs to be here for the BaseSQLDB to create the engine metadata = DummyDBBase.metadata - async def summary(self, group_by, search) -> list[dict[str, str | int]]: - return await self._summary(Cars, group_by, search) + async def summary( + self, group_by: list[str], search: list[SearchSpec] + ) -> list[dict[str, str | int]]: + """Get a summary of the pilots.""" + return await self._summary(table=Cars, group_by=group_by, search=search) async def insert_owner(self, name: str) -> int: stmt = insert(Owners).values(name=name) diff --git a/diracx-db/src/diracx/db/sql/job/db.py b/diracx-db/src/diracx/db/sql/job/db.py index bb28aa5cf..00598663e 100644 --- a/diracx-db/src/diracx/db/sql/job/db.py +++ b/diracx-db/src/diracx/db/sql/job/db.py @@ -15,8 +15,7 @@ from diracx.core.models.job import JobCommand from diracx.core.models.search import SearchSpec, SortSpec -from ..utils import BaseSQLDB, _get_columns -from ..utils.functions import utcnow +from ..utils import BaseSQLDB, _get_columns, utcnow from .schema import ( HeartBeatLoggingInfo, InputData, diff --git a/diracx-db/src/diracx/db/sql/pilot_agents/db.py b/diracx-db/src/diracx/db/sql/pilot_agents/db.py deleted file mode 100644 index 954f081b1..000000000 --- a/diracx-db/src/diracx/db/sql/pilot_agents/db.py +++ /dev/null @@ -1,45 +0,0 @@ -from __future__ import annotations - -from datetime import datetime, timezone - -from sqlalchemy import insert - -from ..utils import BaseSQLDB -from .schema import PilotAgents, PilotAgentsDBBase - - -class PilotAgentsDB(BaseSQLDB): - """PilotAgentsDB class is a front-end to the PilotAgents Database.""" - - metadata = PilotAgentsDBBase.metadata - - async def add_pilot_references( - self, - pilot_ref: list[str], - vo: str, - grid_type: str = "DIRAC", - pilot_stamps: dict | None = None, - ) -> None: - if pilot_stamps is None: - pilot_stamps = {} - - now = datetime.now(tz=timezone.utc) - - # Prepare the list of dictionaries for bulk insertion - values = [ - { - "PilotJobReference": ref, - "VO": vo, - "GridType": grid_type, - "SubmissionTime": now, - "LastUpdateTime": now, - "Status": "Submitted", - "PilotStamp": pilot_stamps.get(ref, ""), - } - for ref in pilot_ref - ] - - # Insert multiple rows in a single execute call - stmt = insert(PilotAgents).values(values) - await self.conn.execute(stmt) - return diff --git a/diracx-db/src/diracx/db/sql/pilot_agents/__init__.py b/diracx-db/src/diracx/db/sql/pilots/__init__.py similarity index 100% rename from diracx-db/src/diracx/db/sql/pilot_agents/__init__.py rename to diracx-db/src/diracx/db/sql/pilots/__init__.py diff --git a/diracx-db/src/diracx/db/sql/pilots/db.py b/diracx-db/src/diracx/db/sql/pilots/db.py new file mode 100644 index 000000000..2cdf6bf39 --- /dev/null +++ b/diracx-db/src/diracx/db/sql/pilots/db.py @@ -0,0 +1,241 @@ +from __future__ import annotations + +from datetime import datetime, timezone +from typing import Any + +from sqlalchemy import bindparam +from sqlalchemy.exc import IntegrityError +from sqlalchemy.sql import delete, insert, update + +from diracx.core.exceptions import ( + PilotAlreadyAssociatedWithJobError, + PilotNotFoundError, +) +from diracx.core.models.pilot import PilotFieldsMapping, PilotStatus +from diracx.core.models.search import SearchSpec, SortSpec + +from ..utils import ( + BaseSQLDB, +) +from .schema import ( + JobToPilotMapping, + PilotAgents, + PilotAgentsDBBase, + PilotOutput, +) + + +class PilotAgentsDB(BaseSQLDB): + """PilotAgentsDB class is a front-end to the PilotAgents Database.""" + + metadata = PilotAgentsDBBase.metadata + + # ----------------------------- Insert Functions ----------------------------- + + async def add_pilots( + self, + pilot_stamps: list[str], + vo: str, + grid_type: str = "DIRAC", + grid_site: str = "Unknown", + destination_site: str = "NotAssigned", + pilot_references: dict[str, str] | None = None, + status: str = PilotStatus.SUBMITTED, + ): + """Bulk add pilots in the DB. + + If we can't find a pilot_reference associated with a stamp, we take the stamp by default. + """ + if pilot_references is None: + pilot_references = {} + + now = datetime.now(tz=timezone.utc) + + # Prepare the list of dictionaries for bulk insertion + values = [ + { + "PilotJobReference": pilot_references.get(stamp, stamp), + "VO": vo, + "GridType": grid_type, + "GridSite": grid_site, + "DestinationSite": destination_site, + "SubmissionTime": now, + "LastUpdateTime": now, + "Status": status, + "PilotStamp": stamp, + } + for stamp in pilot_stamps + ] + + # Insert multiple rows in a single execute call and use 'returning' to get primary keys + stmt = insert(PilotAgents).values(values) # Assuming 'id' is the primary key + + await self.conn.execute(stmt) + + async def add_jobs_to_pilot(self, job_to_pilot_mapping: list[dict[str, Any]]): + """Associate a pilot with jobs. + + job_to_pilot_mapping format: + ```py + job_to_pilot_mapping = [ + {"PilotID": pilot_id, "JobID": job_id, "StartTime": now} + ] + ``` + + Raises: + - PilotNotFoundError if a pilot_id is not associated with a pilot. + - PilotAlreadyAssociatedWithJobError if the pilot is already associated with one of the given jobs. + - NotImplementedError if the integrity error is not caught. + + **Important note**: We assume that a job exists. + + """ + # Insert multiple rows in a single execute call + stmt = insert(JobToPilotMapping).values(job_to_pilot_mapping) + + try: + await self.conn.execute(stmt) + except IntegrityError as e: + if "foreign key" in str(e.orig).lower(): + raise PilotNotFoundError( + detail="at least one of these pilots do not exist", + ) from e + + if ( + "duplicate entry" in str(e.orig).lower() + or "unique constraint" in str(e.orig).lower() + ): + raise PilotAlreadyAssociatedWithJobError( + detail="at least one of these pilots is already associated with a given job." + ) from e + + # Other errors to catch + raise NotImplementedError( + "Engine Specific error not caught" + str(e) + ) from e + + # ----------------------------- Delete Functions ----------------------------- + + async def delete_pilots(self, pilot_ids: list[int]): + """Destructive function. Delete pilots.""" + stmt = delete(PilotAgents).where(PilotAgents.pilot_id.in_(pilot_ids)) + + await self.conn.execute(stmt) + + async def remove_jobs_from_pilots(self, pilot_ids: list[int]): + """Destructive function. De-associate jobs and pilots.""" + stmt = delete(JobToPilotMapping).where( + JobToPilotMapping.pilot_id.in_(pilot_ids) + ) + + await self.conn.execute(stmt) + + async def delete_pilot_logs(self, pilot_ids: list[int]): + """Destructive function. Remove logs from pilots.""" + stmt = delete(PilotOutput).where(PilotOutput.pilot_id.in_(pilot_ids)) + + await self.conn.execute(stmt) + + # ----------------------------- Update Functions ----------------------------- + + async def update_pilot_fields( + self, pilot_stamps_to_fields_mapping: list[PilotFieldsMapping] + ): + """Bulk update pilots with a mapping. + + pilot_stamps_to_fields_mapping format: + ```py + [ + { + "PilotStamp": pilot_stamp, + "BenchMark": bench_mark, + "StatusReason": pilot_reason, + "AccountingSent": accounting_sent, + "Status": status, + "CurrentJobID": current_job_id, + "Queue": queue, + ... + } + ] + ``` + + The mapping helps to update multiple fields at a time. + + Raises PilotNotFoundError if one of the pilots is not found. + """ + stmt = ( + update(PilotAgents) + .where(PilotAgents.pilot_stamp == bindparam("b_pilot_stamp")) + .values( + { + key: bindparam(key) + for key in pilot_stamps_to_fields_mapping[0] + .model_dump(exclude_none=True) + .keys() + if key != "PilotStamp" + } + ) + ) + + values = [ + { + **{"b_pilot_stamp": mapping.PilotStamp}, + **mapping.model_dump(exclude={"PilotStamp"}, exclude_none=True), + } + for mapping in pilot_stamps_to_fields_mapping + ] + + res = await self.conn.execute(stmt, values) + + if res.rowcount != len(pilot_stamps_to_fields_mapping): + raise PilotNotFoundError("at least one of the given pilot does not exist.") + + # ----------------------------- Search Functions ----------------------------- + + async def search_pilots( + self, + parameters: list[str] | None, + search: list[SearchSpec], + sorts: list[SortSpec], + *, + distinct: bool = False, + per_page: int = 100, + page: int | None = None, + ) -> tuple[int, list[dict[Any, Any]]]: + """Search for pilot information in the database.""" + return await self._search( + table=PilotAgents, + parameters=parameters, + search=search, + sorts=sorts, + distinct=distinct, + per_page=per_page, + page=page, + ) + + async def search_pilot_to_job_mapping( + self, + parameters: list[str] | None, + search: list[SearchSpec], + sorts: list[SortSpec], + *, + distinct: bool = False, + per_page: int = 100, + page: int | None = None, + ) -> tuple[int, list[dict[Any, Any]]]: + """Search for jobs that are associated with pilots.""" + return await self._search( + table=JobToPilotMapping, + parameters=parameters, + search=search, + sorts=sorts, + distinct=distinct, + per_page=per_page, + page=page, + ) + + async def pilot_summary( + self, group_by: list[str], search: list[SearchSpec] + ) -> list[dict[str, str | int]]: + """Get a summary of the pilots.""" + return await self._summary(table=PilotAgents, group_by=group_by, search=search) diff --git a/diracx-db/src/diracx/db/sql/pilot_agents/schema.py b/diracx-db/src/diracx/db/sql/pilots/schema.py similarity index 94% rename from diracx-db/src/diracx/db/sql/pilot_agents/schema.py rename to diracx-db/src/diracx/db/sql/pilots/schema.py index 770b62b79..4e0fbb9b2 100644 --- a/diracx-db/src/diracx/db/sql/pilot_agents/schema.py +++ b/diracx-db/src/diracx/db/sql/pilots/schema.py @@ -18,6 +18,7 @@ str255, ) from diracx.db.sql.utils.types import SmarterDateTime +from diracx.core.models.pilot import PilotStatus class PilotAgentsDBBase(DeclarativeBase): @@ -54,14 +55,14 @@ class PilotAgents(PilotAgentsDBBase): last_update_time: Mapped[Optional[datetime]] = mapped_column( "LastUpdateTime", SmarterDateTime ) - status: Mapped[str32] = mapped_column("Status", default="Unknown") + status: Mapped[str32] = mapped_column("Status", default=PilotStatus.UNKNOWN) status_reason: Mapped[str255] = mapped_column("StatusReason", default="Unknown") accounting_sent: Mapped[bool] = mapped_column( "AccountingSent", EnumBackedBool(), default=False ) - __table_args__ = ( Index("PilotJobReference", "PilotJobReference"), + Index("PilotStamp", "PilotStamp"), Index("Status", "Status"), Index("Statuskey", "GridSite", "DestinationSite", "Status"), ) diff --git a/diracx-db/src/diracx/db/sql/utils/__init__.py b/diracx-db/src/diracx/db/sql/utils/__init__.py index ad262186b..ab29c76e8 100644 --- a/diracx-db/src/diracx/db/sql/utils/__init__.py +++ b/diracx-db/src/diracx/db/sql/utils/__init__.py @@ -43,3 +43,19 @@ str512, str1024, ) + +__all__ = ( + "_get_columns", + "apply_search_filters", + "apply_sort_constraints", + "BaseSQLDB", + "Column", + "DateNowColumn", + "EnumBackedBool", + "EnumColumn", + "hash", + "NullColumn", + "substract_date", + "SQLDBUnavailableError", + "utcno", +) diff --git a/diracx-db/tests/pilot_agents/test_pilot_agents_db.py b/diracx-db/tests/pilot_agents/test_pilot_agents_db.py deleted file mode 100644 index 3ca989885..000000000 --- a/diracx-db/tests/pilot_agents/test_pilot_agents_db.py +++ /dev/null @@ -1,30 +0,0 @@ -from __future__ import annotations - -import pytest - -from diracx.db.sql.pilot_agents.db import PilotAgentsDB - - -@pytest.fixture -async def pilot_agents_db(tmp_path) -> PilotAgentsDB: - agents_db = PilotAgentsDB("sqlite+aiosqlite:///:memory:") - async with agents_db.engine_context(): - async with agents_db.engine.begin() as conn: - await conn.run_sync(agents_db.metadata.create_all) - yield agents_db - - -async def test_insert_and_select(pilot_agents_db: PilotAgentsDB): - async with pilot_agents_db as pilot_agents_db: - # Add a pilot reference - refs = [f"ref_{i}" for i in range(10)] - stamps = [f"stamp_{i}" for i in range(10)] - stamp_dict = dict(zip(refs, stamps)) - - await pilot_agents_db.add_pilot_references( - refs, "test_vo", grid_type="DIRAC", pilot_stamps=stamp_dict - ) - - await pilot_agents_db.add_pilot_references( - refs, "test_vo", grid_type="DIRAC", pilot_stamps=None - ) diff --git a/diracx-db/tests/pilot_agents/__init__.py b/diracx-db/tests/pilots/__init__.py similarity index 100% rename from diracx-db/tests/pilot_agents/__init__.py rename to diracx-db/tests/pilots/__init__.py diff --git a/diracx-db/tests/pilots/test_pilot_management.py b/diracx-db/tests/pilots/test_pilot_management.py new file mode 100644 index 000000000..2adabb0d2 --- /dev/null +++ b/diracx-db/tests/pilots/test_pilot_management.py @@ -0,0 +1,193 @@ +from __future__ import annotations + +from datetime import datetime, timezone + +import pytest + +from diracx.core.exceptions import ( + PilotAlreadyAssociatedWithJobError, +) +from diracx.core.models.pilot import PilotFieldsMapping, PilotStatus +from diracx.db.sql.pilots.db import PilotAgentsDB + +from .utils import ( + add_stamps, # noqa: F401 + create_old_pilots_environment, # noqa: F401 + create_timed_pilots, # noqa: F401 + get_pilot_jobs_ids_by_pilot_id, + get_pilots_by_stamp, +) + +MAIN_VO = "lhcb" +N = 100 + + +@pytest.fixture +async def pilot_db(tmp_path): + agents_db = PilotAgentsDB("sqlite+aiosqlite:///:memory:") + async with agents_db.engine_context(): + async with agents_db.engine.begin() as conn: + await conn.run_sync(agents_db.metadata.create_all) + yield agents_db + + +@pytest.mark.asyncio +async def test_insert_and_select(pilot_db: PilotAgentsDB): + async with pilot_db as pilot_db: + # Add pilots + refs = [f"ref_{i}" for i in range(10)] + stamps = [f"stamp_{i}" for i in range(10)] + pilot_references = dict(zip(stamps, refs)) + + await pilot_db.add_pilots( + stamps, MAIN_VO, grid_type="DIRAC", pilot_references=pilot_references + ) + + # Accept duplicates because it is checked by the logic + await pilot_db.add_pilots( + stamps, MAIN_VO, grid_type="DIRAC", pilot_references=None + ) + + +@pytest.mark.asyncio +async def test_insert_and_delete(pilot_db: PilotAgentsDB): + async with pilot_db as pilot_db: + # Add pilots + refs = [f"ref_{i}" for i in range(2)] + stamps = [f"stamp_{i}" for i in range(2)] + pilot_references = dict(zip(stamps, refs)) + + await pilot_db.add_pilots( + stamps, MAIN_VO, grid_type="DIRAC", pilot_references=pilot_references + ) + + # Works, the pilots exists + res = await get_pilots_by_stamp(pilot_db, [stamps[0]]) + await get_pilots_by_stamp(pilot_db, [stamps[0]]) + + # We delete the first pilot + await pilot_db.delete_pilots([res[0]["PilotID"]]) + + # We get the 2nd pilot that is not delete (no error) + await get_pilots_by_stamp(pilot_db, [stamps[1]]) + # We get the 1st pilot that is delete (error) + + assert not await get_pilots_by_stamp(pilot_db, [stamps[0]]) + + +@pytest.mark.asyncio +async def test_insert_and_select_single_then_modify(pilot_db: PilotAgentsDB): + async with pilot_db as pilot_db: + pilot_stamp = "stamp-test" + await pilot_db.add_pilots( + vo=MAIN_VO, + pilot_stamps=[pilot_stamp], + grid_type="grid-type", + ) + + res = await get_pilots_by_stamp(pilot_db, [pilot_stamp]) + assert len(res) == 1 + pilot = res[0] + + # Assert values + assert pilot["VO"] == MAIN_VO + assert pilot["PilotStamp"] == pilot_stamp + assert pilot["GridType"] == "grid-type" + assert pilot["BenchMark"] == 0.0 + assert pilot["Status"] == PilotStatus.SUBMITTED + assert pilot["StatusReason"] == "Unknown" + assert not pilot["AccountingSent"] + + # + # Modify a pilot, then check if every change is done + # + await pilot_db.update_pilot_fields( + [ + PilotFieldsMapping( + PilotStamp=pilot_stamp, + BenchMark=1.0, + StatusReason="NewReason", + AccountingSent=True, + Status=PilotStatus.WAITING, + ) + ] + ) + + res = await get_pilots_by_stamp(pilot_db, [pilot_stamp]) + assert len(res) == 1 + pilot = res[0] + + # Set values + assert pilot["VO"] == MAIN_VO + assert pilot["PilotStamp"] == pilot_stamp + assert pilot["GridType"] == "grid-type" + assert pilot["BenchMark"] == 1.0 + assert pilot["Status"] == PilotStatus.WAITING + assert pilot["StatusReason"] == "NewReason" + assert pilot["AccountingSent"] + + +@pytest.mark.asyncio +async def test_associate_pilot_with_job_and_get_it(pilot_db: PilotAgentsDB): + """We will proceed in few steps. + + 1. Create a pilot + 2. Verify that he is not associated with any job + 3. Associate with jobs + 4. Verify that he is associate with this job + 5. Associate with jobs that he already has and two that he has not + 6. Associate with jobs that he has not, but were involved in a crash + """ + async with pilot_db as pilot_db: + pilot_stamp = "stamp-test" + # Add pilot + await pilot_db.add_pilots( + vo=MAIN_VO, + pilot_stamps=[pilot_stamp], + grid_type="grid-type", + ) + + res = await get_pilots_by_stamp(pilot_db, [pilot_stamp]) + assert len(res) == 1 + pilot = res[0] + pilot_id = pilot["PilotID"] + + # Verify that he has no jobs + assert len(await get_pilot_jobs_ids_by_pilot_id(pilot_db, pilot_id)) == 0 + + now = datetime.now(tz=timezone.utc) + + # Associate pilot with jobs + pilot_jobs = [1, 2, 3] + # Prepare the list of dictionaries for bulk insertion + job_to_pilot_mapping = [ + {"PilotID": pilot_id, "JobID": job_id, "StartTime": now} + for job_id in pilot_jobs + ] + await pilot_db.add_jobs_to_pilot(job_to_pilot_mapping) + + # Verify that he has all jobs + db_jobs = await get_pilot_jobs_ids_by_pilot_id(pilot_db, pilot_id) + # We test both length and if every job is included if for any reason we have duplicates + assert all(job in db_jobs for job in pilot_jobs) + assert len(pilot_jobs) == len(db_jobs) + + # Associate pilot with a job that he already has, and one that he has not + pilot_jobs = [10, 1, 5] + with pytest.raises(PilotAlreadyAssociatedWithJobError): + # Prepare the list of dictionaries for bulk insertion + job_to_pilot_mapping = [ + {"PilotID": pilot_id, "JobID": job_id, "StartTime": now} + for job_id in pilot_jobs + ] + await pilot_db.add_jobs_to_pilot(job_to_pilot_mapping) + + # Associate pilot with jobs that he has not, but was previously in an error + # To test that the rollback worked + pilot_jobs = [5, 10] + # Prepare the list of dictionaries for bulk insertion + job_to_pilot_mapping = [ + {"PilotID": pilot_id, "JobID": job_id, "StartTime": now} + for job_id in pilot_jobs + ] + await pilot_db.add_jobs_to_pilot(job_to_pilot_mapping) diff --git a/diracx-db/tests/pilots/test_query.py b/diracx-db/tests/pilots/test_query.py new file mode 100644 index 000000000..d1e5b1da3 --- /dev/null +++ b/diracx-db/tests/pilots/test_query.py @@ -0,0 +1,292 @@ +from __future__ import annotations + +import pytest + +from diracx.core.exceptions import InvalidQueryError +from diracx.core.models.pilot import PilotFieldsMapping, PilotStatus +from diracx.core.models.search import ScalarSearchOperator, ScalarSearchSpec, SortDirection, SortSpec, VectorSearchOperator, VectorSearchSpec +from diracx.db.sql.pilots.db import PilotAgentsDB + +MAIN_VO = "lhcb" +N = 100 + + +@pytest.fixture +async def pilot_db(tmp_path): + agents_db = PilotAgentsDB("sqlite+aiosqlite:///:memory:") + async with agents_db.engine_context(): + async with agents_db.engine.begin() as conn: + await conn.run_sync(agents_db.metadata.create_all) + yield agents_db + + +PILOT_REASONS = [ + "I was sick", + "I can't, I have a pony.", + "I was shopping", + "I was sleeping", +] + +PILOT_STATUSES = list(PilotStatus) + + +@pytest.fixture +async def populated_pilot_db(pilot_db): + async with pilot_db as pilot_db: + # Add pilots + refs = [f"ref_{i + 1}" for i in range(N)] + stamps = [f"stamp_{i + 1}" for i in range(N)] + pilot_references = dict(zip(stamps, refs)) + + vo = MAIN_VO + + await pilot_db.add_pilots( + stamps, vo, grid_type="DIRAC", pilot_references=pilot_references + ) + + await pilot_db.update_pilot_fields( + [ + PilotFieldsMapping( + PilotStamp=pilot_stamp, + BenchMark=i**2, + StatusReason=PILOT_REASONS[i % len(PILOT_REASONS)], + AccountingSent=True, + Status=PILOT_STATUSES[i % len(PILOT_STATUSES)], + CurrentJobID=i, + Queue=f"queue_{i}", + ) + for i, pilot_stamp in enumerate(stamps) + ] + ) + + yield pilot_db + + +async def test_search_parameters(populated_pilot_db): + """Test that we can search specific parameters for pilots in the database.""" + async with populated_pilot_db as pilot_db: + # Search a specific parameter: PilotID + total, result = await pilot_db.search_pilots(["PilotID"], [], []) + assert total == N + assert result + for r in result: + assert r.keys() == {"PilotID"} + + # Search a specific parameter: Status + total, result = await pilot_db.search_pilots(["Status"], [], []) + assert total == N + assert result + for r in result: + assert r.keys() == {"Status"} + + # Search for multiple parameters: PilotID, Status + total, result = await pilot_db.search_pilots(["PilotID", "Status"], [], []) + assert total == N + assert result + for r in result: + assert r.keys() == {"PilotID", "Status"} + + # Search for a specific parameter but use distinct: Status + total, result = await pilot_db.search_pilots(["Status"], [], [], distinct=True) + assert total == len(PILOT_STATUSES) + assert result + + # Search for a non-existent parameter: Dummy + with pytest.raises(InvalidQueryError): + total, result = await pilot_db.search_pilots(["Dummy"], [], []) + + +async def test_search_conditions(populated_pilot_db): + """Test that we can search for specific pilots in the database.""" + async with populated_pilot_db as pilot_db: + # Search a specific scalar condition: PilotID eq 3 + condition = ScalarSearchSpec( + parameter="PilotID", operator=ScalarSearchOperator.EQUAL, value=3 + ) + total, result = await pilot_db.search_pilots([], [condition], []) + assert total == 1 + assert result + assert len(result) == 1 + assert result[0]["PilotID"] == 3 + + # Search a specific scalar condition: PilotID lt 3 + condition = ScalarSearchSpec( + parameter="PilotID", operator=ScalarSearchOperator.LESS_THAN, value=3 + ) + total, result = await pilot_db.search_pilots([], [condition], []) + assert total == 2 + assert result + assert len(result) == 2 + assert result[0]["PilotID"] == 1 + assert result[1]["PilotID"] == 2 + + # Search a specific scalar condition: PilotID neq 3 + condition = ScalarSearchSpec( + parameter="PilotID", operator=ScalarSearchOperator.NOT_EQUAL, value=3 + ) + total, result = await pilot_db.search_pilots([], [condition], []) + assert total == 99 + assert result + assert len(result) == 99 + assert all(r["PilotID"] != 3 for r in result) + + # Search a specific scalar condition: PilotID eq 5873 (does not exist) + condition = ScalarSearchSpec( + parameter="PilotID", operator=ScalarSearchOperator.EQUAL, value=5873 + ) + total, result = await pilot_db.search_pilots([], [condition], []) + assert not result + + # Search a specific vector condition: PilotID in 1,2,3 + condition = VectorSearchSpec( + parameter="PilotID", operator=VectorSearchOperator.IN, values=[1, 2, 3] + ) + total, result = await pilot_db.search_pilots([], [condition], []) + assert total == 3 + assert result + assert len(result) == 3 + assert all(r["PilotID"] in [1, 2, 3] for r in result) + + # Search a specific vector condition: PilotID in 1,2,5873 (one of them does not exist) + condition = VectorSearchSpec( + parameter="PilotID", operator=VectorSearchOperator.IN, values=[1, 2, 5873] + ) + total, result = await pilot_db.search_pilots([], [condition], []) + assert total == 2 + assert result + assert len(result) == 2 + assert all(r["PilotID"] in [1, 2] for r in result) + + # Search a specific vector condition: PilotID not in 1,2,3 + condition = VectorSearchSpec( + parameter="PilotID", operator=VectorSearchOperator.NOT_IN, values=[1, 2, 3] + ) + total, result = await pilot_db.search_pilots([], [condition], []) + assert total == 97 + assert result + assert len(result) == 97 + assert all(r["PilotID"] not in [1, 2, 3] for r in result) + + # Search a specific vector condition: PilotID not in 1,2,5873 (one of them does not exist) + condition = VectorSearchSpec( + parameter="PilotID", + operator=VectorSearchOperator.NOT_IN, + values=[1, 2, 5873], + ) + total, result = await pilot_db.search_pilots([], [condition], []) + assert total == 98 + assert result + assert len(result) == 98 + assert all(r["PilotID"] not in [1, 2] for r in result) + + # Search for multiple conditions based on different parameters: PilotID eq 70, PilotID in 4,5,6 + condition1 = ScalarSearchSpec( + parameter="PilotStamp", operator=ScalarSearchOperator.EQUAL, value="stamp_5" + ) + condition2 = VectorSearchSpec( + parameter="PilotID", operator=VectorSearchOperator.IN, values=[4, 5, 6] + ) + total, result = await pilot_db.search_pilots([], [condition1, condition2], []) + assert total == 1 + assert result + assert len(result) == 1 + assert result[0]["PilotID"] == 5 + assert result[0]["PilotStamp"] == "stamp_5" + + # Search for multiple conditions based on the same parameter: PilotID eq 70, PilotID in 4,5,6 + condition1 = ScalarSearchSpec( + parameter="PilotID", operator=ScalarSearchOperator.EQUAL, value=70 + ) + condition2 = VectorSearchSpec( + parameter="PilotID", operator=VectorSearchOperator.IN, values=[4, 5, 6] + ) + total, result = await pilot_db.search_pilots([], [condition1, condition2], []) + assert total == 0 + assert not result + + +async def test_search_sorts(populated_pilot_db): + """Test that we can search for pilots in the database and sort the results.""" + async with populated_pilot_db as pilot_db: + # Search and sort by PilotID in ascending order + sort = SortSpec(parameter="PilotID", direction=SortDirection.ASC) + total, result = await pilot_db.search_pilots([], [], [sort]) + assert total == N + assert result + for i, r in enumerate(result): + assert r["PilotID"] == i + 1 + + # Search and sort by PilotID in descending order + sort = SortSpec(parameter="PilotID", direction=SortDirection.DESC) + total, result = await pilot_db.search_pilots([], [], [sort]) + assert total == N + assert result + for i, r in enumerate(result): + assert r["PilotID"] == N - i + + # Search and sort by PilotStamp in ascending order + sort = SortSpec(parameter="PilotStamp", direction=SortDirection.ASC) + total, result = await pilot_db.search_pilots([], [], [sort]) + assert total == N + assert result + # Assert that stamp_10 is before stamp_2 because of the lexicographical order + assert result[2]["PilotStamp"] == "stamp_100" + assert result[12]["PilotStamp"] == "stamp_2" + + # Search and sort by PilotStamp in descending order + sort = SortSpec(parameter="PilotStamp", direction=SortDirection.DESC) + total, result = await pilot_db.search_pilots([], [], [sort]) + assert total == N + assert result + # Assert that stamp_10 is before stamp_2 because of the lexicographical order + assert result[97]["PilotStamp"] == "stamp_100" + assert result[87]["PilotStamp"] == "stamp_2" + + # Search and sort by PilotStamp in ascending order and PilotID in descending order + sort1 = SortSpec(parameter="PilotStamp", direction=SortDirection.ASC) + sort2 = SortSpec(parameter="PilotID", direction=SortDirection.DESC) + total, result = await pilot_db.search_pilots([], [], [sort1, sort2]) + assert total == N + assert result + assert result[0]["PilotStamp"] == "stamp_1" + assert result[0]["PilotID"] == 1 + assert result[99]["PilotStamp"] == "stamp_99" + assert result[99]["PilotID"] == 99 + + +@pytest.mark.parametrize( + "per_page, page, expected_len, expected_first_id, expect_exception", + [ + (10, 1, 10, 1, None), # Page 1 + (10, 2, 10, 11, None), # Page 2 + (10, 10, 10, 91, None), # Page 10 + (50, 2, 50, 51, None), # Page 2 with 50 per page + (10, 11, 0, None, None), # Page beyond range, should return empty + (10, 0, None, None, InvalidQueryError), # Invalid page + (0, 1, None, None, InvalidQueryError), # Invalid per_page + ], +) +async def test_search_pagination( + populated_pilot_db, + per_page, + page, + expected_len, + expected_first_id, + expect_exception, +): + """Test pagination logic in pilot search.""" + async with populated_pilot_db as pilot_db: + if expect_exception: + with pytest.raises(expect_exception): + await pilot_db.search_pilots([], [], [], per_page=per_page, page=page) + else: + total, result = await pilot_db.search_pilots( + [], [], [], per_page=per_page, page=page + ) + assert total == N + if expected_len == 0: + assert not result + else: + assert result + assert len(result) == expected_len + assert result[0]["PilotID"] == expected_first_id diff --git a/diracx-db/tests/pilots/utils.py b/diracx-db/tests/pilots/utils.py new file mode 100644 index 000000000..c7f2e2908 --- /dev/null +++ b/diracx-db/tests/pilots/utils.py @@ -0,0 +1,146 @@ +from __future__ import annotations + +from datetime import datetime, timezone +from typing import Any + +import pytest +from sqlalchemy import update + +from diracx.core.models.search import ScalarSearchOperator, ScalarSearchSpec, VectorSearchOperator, VectorSearchSpec +from diracx.db.sql.pilots.db import PilotAgentsDB +from diracx.db.sql.pilots.schema import PilotAgents + +MAIN_VO = "lhcb" +N = 100 + +# ------------ Fetching data ------------ + + +async def get_pilots_by_stamp( + pilot_db: PilotAgentsDB, pilot_stamps: list[str], parameters: list[str] = [] +) -> list[dict[Any, Any]]: + _, pilots = await pilot_db.search_pilots( + parameters=parameters, + search=[ + VectorSearchSpec( + parameter="PilotStamp", + operator=VectorSearchOperator.IN, + values=pilot_stamps, + ) + ], + sorts=[], + distinct=True, + per_page=1000, + ) + + return pilots + + +async def get_pilot_jobs_ids_by_pilot_id( + pilot_db: PilotAgentsDB, pilot_id: int +) -> list[int]: + _, jobs = await pilot_db.search_pilot_to_job_mapping( + parameters=["JobID"], + search=[ + ScalarSearchSpec( + parameter="PilotID", + operator=ScalarSearchOperator.EQUAL, + value=pilot_id, + ) + ], + sorts=[], + distinct=True, + per_page=10000, + ) + + return [job["JobID"] for job in jobs] + + +# ------------ Creating data ------------ + + +@pytest.fixture +async def add_stamps(pilot_db): + async def _add_stamps(start_n=0): + async with pilot_db as db: + # Add pilots + refs = [f"ref_{i}" for i in range(start_n, start_n + N)] + stamps = [f"stamp_{i}" for i in range(start_n, start_n + N)] + pilot_references = dict(zip(stamps, refs)) + + vo = MAIN_VO + + await db.add_pilots( + stamps, vo, grid_type="DIRAC", pilot_references=pilot_references + ) + + return await get_pilots_by_stamp(db, stamps) + + return _add_stamps + + +@pytest.fixture +async def create_timed_pilots(pilot_db, add_stamps): + async def _create_timed_pilots( + old_date: datetime, aborted: bool = False, start_n=0 + ): + # Get pilots + pilots = await add_stamps(start_n) + + async with pilot_db as db: + # Update manually their age + # Collect PilotStamps + pilot_stamps = [pilot["PilotStamp"] for pilot in pilots] + + stmt = ( + update(PilotAgents) + .where(PilotAgents.pilot_stamp.in_(pilot_stamps)) + .values(SubmissionTime=old_date) + ) + + if aborted: + stmt = stmt.values(Status="Aborted") + + res = await db.conn.execute(stmt) + assert res.rowcount == len(pilot_stamps) + + pilots = await get_pilots_by_stamp(db, pilot_stamps) + return pilots + + return _create_timed_pilots + + +@pytest.fixture +async def create_old_pilots_environment(pilot_db, create_timed_pilots): + non_aborted_recent = await create_timed_pilots( + datetime(2025, 1, 1, tzinfo=timezone.utc), False, N + ) + aborted_recent = await create_timed_pilots( + datetime(2025, 1, 1, tzinfo=timezone.utc), True, 2 * N + ) + + aborted_very_old = await create_timed_pilots( + datetime(2003, 3, 10, tzinfo=timezone.utc), True, 3 * N + ) + non_aborted_very_old = await create_timed_pilots( + datetime(2003, 3, 10, tzinfo=timezone.utc), False, 4 * N + ) + + pilot_number = 4 * N + + assert pilot_number == ( + len(non_aborted_recent) + + len(aborted_recent) + + len(aborted_very_old) + + len(non_aborted_very_old) + ) + + # Phase 0. Verify that we have the right environment + async with pilot_db as pilot_db: + # Ensure that we can get every pilot (only get first of each group) + await get_pilots_by_stamp(pilot_db, [non_aborted_recent[0]["PilotStamp"]]) + await get_pilots_by_stamp(pilot_db, [aborted_recent[0]["PilotStamp"]]) + await get_pilots_by_stamp(pilot_db, [aborted_very_old[0]["PilotStamp"]]) + await get_pilots_by_stamp(pilot_db, [non_aborted_very_old[0]["PilotStamp"]]) + + return non_aborted_recent, aborted_recent, non_aborted_very_old, aborted_very_old diff --git a/diracx-db/tests/test_dummy_db.py b/diracx-db/tests/test_dummy_db.py index f94eda5b7..8e324a28e 100644 --- a/diracx-db/tests/test_dummy_db.py +++ b/diracx-db/tests/test_dummy_db.py @@ -149,6 +149,7 @@ async def test_failed_transaction(dummy_db): assert result # This will raise an exception and the transaction will be rolled back + result = await dummy_db.summary(["unexistingfieldraisinganerror"], []) assert result[0]["count"] == 10 diff --git a/diracx-logic/src/diracx/logic/pilots/__init__.py b/diracx-logic/src/diracx/logic/pilots/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/diracx-logic/src/diracx/logic/pilots/management.py b/diracx-logic/src/diracx/logic/pilots/management.py new file mode 100644 index 000000000..9b9ce9f9f --- /dev/null +++ b/diracx-logic/src/diracx/logic/pilots/management.py @@ -0,0 +1,122 @@ +from __future__ import annotations + +from datetime import datetime, timedelta, timezone + +from diracx.core.exceptions import PilotAlreadyExistsError, PilotNotFoundError +from diracx.core.models.pilot import PilotFieldsMapping +from diracx.db.sql import PilotAgentsDB + +from .query import ( + get_outdated_pilots, + get_pilot_ids_by_stamps, + get_pilot_jobs_ids_by_pilot_id, + get_pilots_by_stamp, +) + + +async def register_new_pilots( + pilot_db: PilotAgentsDB, + pilot_stamps: list[str], + vo: str, + grid_type: str, + grid_site: str, + destination_site: str, + status: str, + pilot_job_references: dict[str, str] | None, +): + # [IMPORTANT] Check unicity of pilot stamps + # If a pilot already exists, we raise an error (transaction will rollback) + existing_pilots = await get_pilots_by_stamp( + pilot_db=pilot_db, pilot_stamps=pilot_stamps + ) + + # If we found pilots from the list, this means some pilots already exists + if len(existing_pilots) > 0: + found_keys = {pilot["PilotStamp"] for pilot in existing_pilots} + + raise PilotAlreadyExistsError( + f"The following pilots already exist: {found_keys}" + ) + + await pilot_db.add_pilots( + pilot_stamps=pilot_stamps, + vo=vo, + grid_type=grid_type, + grid_site=grid_site, + destination_site=destination_site, + pilot_references=pilot_job_references, + status=status, + ) + + +async def delete_pilots( + pilot_db: PilotAgentsDB, + pilot_stamps: list[str] | None = None, + age_in_days: int | None = None, + delete_only_aborted: bool = True, + vo_constraint: str | None = None, +): + if pilot_stamps: + pilot_ids = await get_pilot_ids_by_stamps( + pilot_db=pilot_db, pilot_stamps=pilot_stamps, allow_missing=True + ) + else: + assert age_in_days + assert vo_constraint + + cutoff_date = datetime.now(tz=timezone.utc) - timedelta(days=age_in_days) + + pilots = await get_outdated_pilots( + pilot_db=pilot_db, + cutoff_date=cutoff_date, + only_aborted=delete_only_aborted, + parameters=["PilotID"], + vo_constraint=vo_constraint, + ) + + pilot_ids = [pilot["PilotID"] for pilot in pilots] + + await pilot_db.remove_jobs_from_pilots(pilot_ids) + await pilot_db.delete_pilot_logs(pilot_ids) + await pilot_db.delete_pilots(pilot_ids) + + +async def update_pilots_fields( + pilot_db: PilotAgentsDB, pilot_stamps_to_fields_mapping: list[PilotFieldsMapping] +): + await pilot_db.update_pilot_fields(pilot_stamps_to_fields_mapping) + + +async def add_jobs_to_pilot( + pilot_db: PilotAgentsDB, pilot_stamp: str, job_ids: list[int] +): + pilot_ids = await get_pilot_ids_by_stamps( + pilot_db=pilot_db, pilot_stamps=[pilot_stamp] + ) + pilot_id = pilot_ids[0] + + now = datetime.now(tz=timezone.utc) + + # Prepare the list of dictionaries for bulk insertion + job_to_pilot_mapping = [ + {"PilotID": pilot_id, "JobID": job_id, "StartTime": now} for job_id in job_ids + ] + + await pilot_db.add_jobs_to_pilot( + job_to_pilot_mapping=job_to_pilot_mapping, + ) + + +async def get_pilot_jobs_ids_by_stamp( + pilot_db: PilotAgentsDB, pilot_stamp: str +) -> list[int]: + """Fetch pilot jobs by stamp.""" + try: + pilot_ids = await get_pilot_ids_by_stamps( + pilot_db=pilot_db, pilot_stamps=[pilot_stamp] + ) + pilot_id = pilot_ids[0] + except PilotNotFoundError: + return [] + + return await get_pilot_jobs_ids_by_pilot_id(pilot_db=pilot_db, pilot_id=pilot_id) diff --git a/diracx-logic/src/diracx/logic/pilots/query.py b/diracx-logic/src/diracx/logic/pilots/query.py new file mode 100644 index 000000000..7487e0bfc --- /dev/null +++ b/diracx-logic/src/diracx/logic/pilots/query.py @@ -0,0 +1,183 @@ +from __future__ import annotations + +from datetime import datetime +from typing import Any + +from diracx.core.exceptions import PilotNotFoundError +from diracx.core.models.pilot import PilotStatus +from diracx.core.models.search import ScalarSearchOperator, ScalarSearchSpec, SearchParams, SearchSpec, SummaryParams, VectorSearchOperator, VectorSearchSpec +from diracx.db.sql import PilotAgentsDB + +MAX_PER_PAGE = 10000 + + +async def search( + pilot_db: PilotAgentsDB, + user_vo: str, + page: int = 1, + per_page: int = 100, + body: SearchParams | None = None, +) -> tuple[int, list[dict[str, Any]]]: + """Retrieve information about jobs.""" + # Apply a limit to per_page to prevent abuse of the API + if per_page > MAX_PER_PAGE: + per_page = MAX_PER_PAGE + + if body is None: + body = SearchParams() + + body.search.append( + ScalarSearchSpec( + parameter="VO", operator=ScalarSearchOperator.EQUAL, value=user_vo + ) + ) + + total, pilots = await pilot_db.search_pilots( + body.parameters, + body.search, + body.sort, + distinct=body.distinct, + page=page, + per_page=per_page, + ) + + return total, pilots + + +async def get_pilots_by_stamp( + pilot_db: PilotAgentsDB, + pilot_stamps: list[str], + parameters: list[str] = [], + allow_missing: bool = True, +) -> list[dict[Any, Any]]: + """Get pilots by their stamp. + + If `allow_missing` is set to False, if a pilot is missing, PilotNotFoundError will be raised. + """ + if parameters: + parameters.append("PilotStamp") + + _, pilots = await pilot_db.search_pilots( + parameters=parameters, + search=[ + VectorSearchSpec( + parameter="PilotStamp", + operator=VectorSearchOperator.IN, + values=pilot_stamps, + ) + ], + sorts=[], + distinct=True, + per_page=MAX_PER_PAGE, + ) + + # allow_missing is set as True by default to mark explicitly when we allow or not + if not allow_missing: + # Custom handling, to see which pilot_stamp does not exist (if so, say which one) + found_keys = {row["PilotStamp"] for row in pilots} + missing = set(pilot_stamps) - found_keys + + if missing: + raise PilotNotFoundError( + detail=str(missing), + ) + + return pilots + + +async def get_pilot_ids_by_stamps( + pilot_db: PilotAgentsDB, pilot_stamps: list[str], allow_missing=False +) -> list[int]: + pilots = await get_pilots_by_stamp( + pilot_db=pilot_db, + pilot_stamps=pilot_stamps, + parameters=["PilotID"], + allow_missing=allow_missing, + ) + + return [pilot["PilotID"] for pilot in pilots] + + +async def get_pilot_jobs_ids_by_pilot_id( + pilot_db: PilotAgentsDB, pilot_id: int +) -> list[int]: + _, jobs = await pilot_db.search_pilot_to_job_mapping( + parameters=["JobID"], + search=[ + ScalarSearchSpec( + parameter="PilotID", + operator=ScalarSearchOperator.EQUAL, + value=pilot_id, + ) + ], + sorts=[], + distinct=True, + per_page=MAX_PER_PAGE, + ) + + return [job["JobID"] for job in jobs] + + +async def get_pilot_ids_by_job_id(pilot_db: PilotAgentsDB, job_id: int) -> list[int]: + _, pilots = await pilot_db.search_pilot_to_job_mapping( + parameters=["PilotID"], + search=[ + ScalarSearchSpec( + parameter="JobID", + operator=ScalarSearchOperator.EQUAL, + value=job_id, + ) + ], + sorts=[], + distinct=True, + per_page=MAX_PER_PAGE, + ) + + return [pilot["PilotID"] for pilot in pilots] + + +async def get_outdated_pilots( + pilot_db: PilotAgentsDB, + cutoff_date: datetime, + vo_constraint: str, + only_aborted: bool = True, + parameters: list[str] = [], +): + query: list[SearchSpec] = [ + ScalarSearchSpec( + parameter="SubmissionTime", + operator=ScalarSearchOperator.LESS_THAN, + value=cutoff_date, + ), + # Add VO to avoid deleting other VO's pilots + ScalarSearchSpec( + parameter="VO", operator=ScalarSearchOperator.EQUAL, value=vo_constraint + ), + ] + + if only_aborted: + query.append( + ScalarSearchSpec( + parameter="Status", + operator=ScalarSearchOperator.EQUAL, + value=PilotStatus.ABORTED, + ) + ) + + _, pilots = await pilot_db.search_pilots( + parameters=parameters, search=query, sorts=[] + ) + + return pilots + + +async def summary(pilot_db: PilotAgentsDB, body: SummaryParams, vo: str): + """Show information suitable for plotting.""" + body.search.append( + { + "parameter": "VO", + "operator": ScalarSearchOperator.EQUAL, + "value": vo, + } + ) + return await pilot_db.pilot_summary(body.grouping, body.search) diff --git a/diracx-routers/pyproject.toml b/diracx-routers/pyproject.toml index 6cdb993b6..b911e0fa1 100644 --- a/diracx-routers/pyproject.toml +++ b/diracx-routers/pyproject.toml @@ -6,11 +6,11 @@ requires-python = ">=3.11" keywords = [] license = { text = "GPL-3.0-only" } classifiers = [ - "Intended Audience :: Science/Research", - "License :: OSI Approved :: GNU General Public License v3 (GPLv3)", - "Programming Language :: Python :: 3", - "Topic :: Scientific/Engineering", - "Topic :: System :: Distributed Computing", + "Intended Audience :: Science/Research", + "License :: OSI Approved :: GNU General Public License v3 (GPLv3)", + "Programming Language :: Python :: 3", + "Topic :: Scientific/Engineering", + "Topic :: System :: Distributed Computing", ] dependencies = [ "cachetools", @@ -34,12 +34,14 @@ dependencies = [ dynamic = ["version"] [project.optional-dependencies] -testing = ["diracx-testing", "moto[server]", "pytest-httpx", "freezegun", "pyjwt"] -types = [ - "types-cachetools", - "types-python-dateutil", - "types-PyYAML", +testing = [ + "diracx-testing", + "moto[server]", + "pytest-httpx", + "freezegun", + "pyjwt", ] +types = ["types-cachetools", "types-python-dateutil", "types-PyYAML"] [project.entry-points."diracx.services"] ".well-known" = "diracx.routers.auth.well_known:router" @@ -47,10 +49,12 @@ auth = "diracx.routers.auth:router" config = "diracx.routers.configuration:router" health = "diracx.routers.health:router" jobs = "diracx.routers.jobs:router" +pilots = "diracx.routers.pilots:router" [project.entry-points."diracx.access_policies"] wms = "diracx.routers.jobs.access_policies:WMSAccessPolicy" sandbox = "diracx.routers.jobs.access_policies:SandboxAccessPolicy" +pilot = "diracx.routers.pilots.access_policies:PilotManagementAccessPolicy" # Minimum version of the client supported [project.entry-points."diracx.min_client_version"] @@ -75,16 +79,16 @@ packages = ["src/diracx"] [tool.pytest.ini_options] testpaths = ["tests"] addopts = [ - "-v", - "--cov=diracx.routers", - "--cov-report=term-missing", - "-pdiracx.testing", - "-pdiracx.testing.osdb", - "--import-mode=importlib", + "-v", + "--cov=diracx.routers", + "--cov-report=term-missing", + "-pdiracx.testing", + "-pdiracx.testing.osdb", + "--import-mode=importlib", ] asyncio_mode = "auto" markers = [ - "enabled_dependencies: List of dependencies which should be available to the FastAPI test client", + "enabled_dependencies: List of dependencies which should be available to the FastAPI test client", ] diff --git a/diracx-routers/src/diracx/routers/pilots/__init__.py b/diracx-routers/src/diracx/routers/pilots/__init__.py new file mode 100644 index 000000000..03f9b8422 --- /dev/null +++ b/diracx-routers/src/diracx/routers/pilots/__init__.py @@ -0,0 +1,13 @@ +from __future__ import annotations + +import logging + +from ..fastapi_classes import DiracxRouter +from .management import router as management_router +from .query import router as query_router + +logger = logging.getLogger(__name__) + +router = DiracxRouter() +router.include_router(management_router) +router.include_router(query_router) diff --git a/diracx-routers/src/diracx/routers/pilots/access_policies.py b/diracx-routers/src/diracx/routers/pilots/access_policies.py new file mode 100644 index 000000000..011633d9b --- /dev/null +++ b/diracx-routers/src/diracx/routers/pilots/access_policies.py @@ -0,0 +1,125 @@ +from __future__ import annotations + +from collections.abc import Callable +from enum import StrEnum, auto +from typing import Annotated + +from fastapi import Depends, HTTPException, status + +from diracx.core.models.search import VectorSearchOperator, VectorSearchSpec +from diracx.core.properties import GENERIC_PILOT, SERVICE_ADMINISTRATOR +from diracx.db.sql.job.db import JobDB +from diracx.db.sql.pilots.db import PilotAgentsDB +from diracx.logic.pilots.query import get_pilots_by_stamp +from diracx.routers.access_policies import BaseAccessPolicy +from diracx.routers.utils.users import AuthorizedUserInfo + + +class ActionType(StrEnum): + # Change some pilot fields + MANAGE_PILOTS = auto() + # Read some pilot info + READ_PILOT_FIELDS = auto() + + +class PilotManagementAccessPolicy(BaseAccessPolicy): + """Rules: + * Every user can access data about his VO + * An administrator can modify a pilot. + """ + + @staticmethod + async def policy( + policy_name: str, + user_info: AuthorizedUserInfo, + /, + *, + action: ActionType | None = None, + pilot_db: PilotAgentsDB | None = None, + pilot_stamps: list[str] | None = None, + job_db: JobDB | None = None, + job_ids: list[int] | None = None, + allow_legacy_pilots: bool = False, + ): + assert action, "action is a mandatory parameter" + + # Users can query + # NOTE: Add into queries a VO constraint + # To manage pilots, user have to be an admin + # In some special cases (described with allow_legacy_pilots), we can allow pilots + if action == ActionType.MANAGE_PILOTS: + # To make it clear, we separate + is_an_admin = SERVICE_ADMINISTRATOR in user_info.properties + is_a_pilot_if_allowed = ( + allow_legacy_pilots and GENERIC_PILOT in user_info.properties + ) + + if not is_an_admin and not is_a_pilot_if_allowed: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="You don't have the permission to manage pilots.", + ) + + if action == ActionType.READ_PILOT_FIELDS: + if GENERIC_PILOT in user_info.properties: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Pilots can't read other pilots info.", + ) + + # + # Additional checks if job_ids or pilot_stamps are provided + # + + # First, if job_ids are provided, we check who is the owner + if job_db and job_ids: + job_owners = await job_db.summary( + ["Owner", "VO"], + [ + VectorSearchSpec( + parameter="JobID", + operator=VectorSearchOperator.IN, + values=job_ids, + ) + ], + ) + + expected_owner = { + "Owner": user_info.preferred_username, + "VO": user_info.vo, + "count": len(set(job_ids)), + } + # All the jobs belong to the user doing the query + # and all of them are present + if not job_owners == [expected_owner]: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="You don't have the rights to modify a pilot.", + ) + + # This is for example when we submit pilots, we use the user VO, so no need to verify + if pilot_db and pilot_stamps: + # Else, check its VO + pilots = await get_pilots_by_stamp( + pilot_db=pilot_db, + pilot_stamps=pilot_stamps, + parameters=["VO"], + allow_missing=True, + ) + + if len(pilots) != len(pilot_stamps): + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="At least one pilot does not exist.", + ) + + if not all(pilot["VO"] == user_info.vo for pilot in pilots): + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="You don't have access to all pilots.", + ) + + +CheckPilotManagementPolicyCallable = Annotated[ + Callable, Depends(PilotManagementAccessPolicy.check) +] diff --git a/diracx-routers/src/diracx/routers/pilots/management.py b/diracx-routers/src/diracx/routers/pilots/management.py new file mode 100644 index 000000000..48c604b96 --- /dev/null +++ b/diracx-routers/src/diracx/routers/pilots/management.py @@ -0,0 +1,257 @@ +from __future__ import annotations + +from http import HTTPStatus +from typing import Annotated + +from fastapi import Body, Depends, HTTPException, Query, status + +from diracx.core.exceptions import ( + PilotAlreadyExistsError, +) +from diracx.core.models.pilot import PilotFieldsMapping, PilotStatus +from diracx.core.properties import GENERIC_PILOT +from diracx.db.sql import JobDB, PilotAgentsDB +from diracx.logic.pilots.management import ( + delete_pilots as delete_pilots_bl, +) +from diracx.logic.pilots.management import ( + get_pilot_jobs_ids_by_stamp, + register_new_pilots, + update_pilots_fields, +) +from diracx.logic.pilots.query import get_pilot_ids_by_job_id +from diracx.routers.utils.users import AuthorizedUserInfo, verify_dirac_access_token + +from ..fastapi_classes import DiracxRouter +from .access_policies import ( + ActionType, + CheckPilotManagementPolicyCallable, +) + +router = DiracxRouter() + + +@router.post("/") +async def add_pilot_stamps( + pilot_db: PilotAgentsDB, + pilot_stamps: Annotated[ + list[str], + Body(description="List of the pilot stamps we want to add to the db."), + ], + vo: Annotated[str, Body(description="Pilot virtual organization.")], + check_permissions: CheckPilotManagementPolicyCallable, + user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)], + grid_type: Annotated[str, Body(description="Grid type of the pilots.")] = "Dirac", + grid_site: Annotated[str, Body(description="Pilots grid site.")] = "Unknown", + destination_site: Annotated[ + str, Body(description="Pilots destination site.") + ] = "NotAssigned", + pilot_references: Annotated[ + dict[str, str] | None, + Body(description="Association of a pilot reference with a pilot stamp."), + ] = None, + pilot_status: Annotated[ + PilotStatus, Body(description="Status of the pilots.") + ] = PilotStatus.SUBMITTED, +): + """Endpoint where a you can create pilots with their references. + + If a pilot stamp already exists, it will block the insertion. + """ + # TODO: Verify that grid types, sites, destination sites, etc. are valids + await check_permissions( + action=ActionType.MANAGE_PILOTS, + allow_legacy_pilots=True, # dirac-admin-add-pilot + ) + + # Prevent someone who stole a pilot X509 to create thousands of pilots at a time + # (It would be still able to create thousands of pilots, but slower) + if GENERIC_PILOT in user_info.properties: + if len(pilot_stamps) != 1: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="As a pilot, you can only create yourself.", + ) + + try: + await register_new_pilots( + pilot_db=pilot_db, + pilot_stamps=pilot_stamps, + vo=vo, + grid_type=grid_type, + grid_site=grid_site, + destination_site=destination_site, + pilot_job_references=pilot_references, + status=pilot_status, + ) + except PilotAlreadyExistsError as e: + raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(e)) from e + + +@router.delete("/", status_code=HTTPStatus.NO_CONTENT) +async def delete_pilots( + pilot_db: PilotAgentsDB, + check_permissions: CheckPilotManagementPolicyCallable, + user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)], + pilot_stamps: Annotated[ + list[str] | None, Query(description="Stamps of the pilots we want to delete.") + ] = None, + age_in_days: Annotated[ + int | None, + Query( + description=( + "The number of days that define the maximum age of pilots to be deleted." + "Pilots older than this age will be considered for deletion." + ) + ), + ] = None, + delete_only_aborted: Annotated[ + bool, + Query( + description=( + "Flag indicating whether to only delete pilots whose status is 'Aborted'." + "If set to True, only pilots with the 'Aborted' status will be deleted." + "It is set by default as True to avoid any mistake." + "This flag is only used for deletion by time." + ) + ), + ] = False, +): + """Endpoint to delete a pilot. + + Two features: + + 1. Or you provide pilot_stamps, so you can delete pilots by their stamp + 2. Or you provide age_in_days, so you can delete pilots that lived more than age_in_days days. + + Note: If you delete a pilot, its logs and its associations with jobs WILL be deleted. + """ + vo_constraint: str | None = None + + # If we delete by pilot_stamps, we check that we can access them + # Else, we add a constraint to the request, to avoid deleting pilots from another VO + if pilot_stamps: + await check_permissions( + action=ActionType.MANAGE_PILOTS, + pilot_db=pilot_db, + pilot_stamps=pilot_stamps, + ) + else: + vo_constraint = user_info.vo + + if not pilot_stamps and not age_in_days: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="pilot_stamps or age_in_days have to be provided.", + ) + + await delete_pilots_bl( + pilot_db=pilot_db, + pilot_stamps=pilot_stamps, + age_in_days=age_in_days, + delete_only_aborted=delete_only_aborted, + vo_constraint=vo_constraint, + ) + + +EXAMPLE_UPDATE_FIELDS = { + "Update the BenchMark field": { + "summary": "Update BenchMark", + "description": "Update only the BenchMark for one pilot.", + "value": { + "pilot_stamps_to_fields_mapping": [ + {"PilotStamp": "the_pilot_stamp", "BenchMark": 1.0} + ] + }, + }, + "Update multiple statuses": { + "summary": "Update multiple pilots", + "description": "Update multiple pilots statuses.", + "value": { + "pilot_stamps_to_fields_mapping": [ + {"PilotStamp": "the_first_pilot_stamp", "Status": "Waiting"}, + {"PilotStamp": "the_second_pilot_stamp", "Status": "Waiting"}, + ] + }, + }, +} + + +@router.patch("/metadata", status_code=HTTPStatus.NO_CONTENT) +async def update_pilot_fields( + pilot_stamps_to_fields_mapping: Annotated[ + list[PilotFieldsMapping], + Body( + description="(pilot_stamp, pilot_fields) mapping to change.", + embed=True, + openapi_examples=EXAMPLE_UPDATE_FIELDS, # type: ignore + ), + ], + pilot_db: PilotAgentsDB, + check_permissions: CheckPilotManagementPolicyCallable, + user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)], +): + """Modify a field of a pilot. + + Note: Only the fields in PilotFieldsMapping are mutable, except for the PilotStamp. + """ + # Ensures stamps validity + pilot_stamps = [mapping.PilotStamp for mapping in pilot_stamps_to_fields_mapping] + await check_permissions( + action=ActionType.MANAGE_PILOTS, + pilot_db=pilot_db, + pilot_stamps=pilot_stamps, + allow_legacy_pilots=True, # dirac-admin-add-pilot + ) + + # Prevent someone who stole a pilot X509 to modify thousands of pilots at a time + # (It would be still able to modify thousands of pilots, but slower) + # We are not able to affirm that this pilot modifies itself + if GENERIC_PILOT in user_info.properties: + if len(pilot_stamps) != 1: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="As a pilot, you can only modify yourself.", + ) + + await update_pilots_fields( + pilot_db=pilot_db, + pilot_stamps_to_fields_mapping=pilot_stamps_to_fields_mapping, + ) + + +@router.get("/jobs") +async def get_pilot_jobs( + pilot_db: PilotAgentsDB, + job_db: JobDB, + check_permissions: CheckPilotManagementPolicyCallable, + pilot_stamp: Annotated[ + str | None, Query(description="The stamp of the pilot.") + ] = None, + job_id: Annotated[int | None, Query(description="The ID of the job.")] = None, +) -> list[int]: + """Endpoint only for admins, to get jobs of a pilot.""" + if pilot_stamp: + # Check VO + await check_permissions( + action=ActionType.READ_PILOT_FIELDS, + pilot_db=pilot_db, + pilot_stamps=[pilot_stamp], + ) + + return await get_pilot_jobs_ids_by_stamp( + pilot_db=pilot_db, + pilot_stamp=pilot_stamp, + ) + elif job_id: + # Check job owner + await check_permissions( + action=ActionType.READ_PILOT_FIELDS, job_db=job_db, job_ids=[job_id] + ) + + return await get_pilot_ids_by_job_id(pilot_db=pilot_db, job_id=job_id) + + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail="You must provide either pilot_stamp or job_id", + ) diff --git a/diracx-routers/src/diracx/routers/pilots/query.py b/diracx-routers/src/diracx/routers/pilots/query.py new file mode 100644 index 000000000..29001e0c7 --- /dev/null +++ b/diracx-routers/src/diracx/routers/pilots/query.py @@ -0,0 +1,165 @@ +from __future__ import annotations + +from http import HTTPStatus +from typing import Annotated, Any + +from fastapi import Body, Depends, Response + +from diracx.core.models.search import SearchParams, SummaryParams +from diracx.db.sql import PilotAgentsDB +from diracx.logic.pilots.query import search as search_bl +from diracx.logic.pilots.query import summary as summary_bl + +from ..fastapi_classes import DiracxRouter +from ..utils.users import AuthorizedUserInfo, verify_dirac_access_token +from .access_policies import ( + ActionType, + CheckPilotManagementPolicyCallable, +) + +router = DiracxRouter() + +EXAMPLE_SEARCHES = { + "Show all": { + "summary": "Show all", + "description": "Shows all pilots the current user has access to.", + "value": {}, + }, + "A specific pilot": { + "summary": "A specific pilot", + "description": "Search for a specific pilot by ID", + "value": {"search": [{"parameter": "PilotID", "operator": "eq", "value": "5"}]}, + }, + "Get ordered pilot statuses": { + "summary": "Get ordered pilot statuses", + "description": "Get only pilot statuses for specific pilots, ordered by status", + "value": { + "parameters": ["PilotID", "Status"], + "search": [ + {"parameter": "PilotID", "operator": "in", "values": ["6", "2", "3"]} + ], + "sort": [{"parameter": "PilotID", "direction": "asc"}], + }, + }, +} + + +EXAMPLE_RESPONSES: dict[int | str, dict[str, Any]] = { + 200: { + "description": "List of matching results", + "content": { + "application/json": { + "example": [ + { + "PilotID": 3, + "SubmissionTime": "2023-05-25T07:03:35.602654", + "LastUpdateTime": "2023-05-25T07:03:35.602656", + "Status": "RUNNING", + "GridType": "Dirac", + "BenchMark": 1.0, + }, + { + "PilotID": 5, + "SubmissionTime": "2023-06-25T07:03:35.602654", + "LastUpdateTime": "2023-07-25T07:03:35.602652", + "Status": "RUNNING", + "GridType": "Dirac", + "BenchMark": 63.1, + }, + ] + } + }, + }, + 206: { + "description": "Partial Content. Only a part of the requested range could be served.", + "headers": { + "Content-Range": { + "description": "The range of pilots returned in this response", + "schema": {"type": "string", "example": "pilots 0-1/4"}, + } + }, + "model": list[dict[str, Any]], + "content": { + "application/json": { + "example": [ + { + "PilotID": 3, + "SubmissionTime": "2023-05-25T07:03:35.602654", + "LastUpdateTime": "2023-05-25T07:03:35.602656", + "Status": "RUNNING", + "GridType": "Dirac", + "BenchMark": 1.0, + }, + { + "PilotID": 5, + "SubmissionTime": "2023-06-25T07:03:35.602654", + "LastUpdateTime": "2023-07-25T07:03:35.602652", + "Status": "RUNNING", + "GridType": "Dirac", + "BenchMark": 63.1, + }, + ] + } + }, + }, +} + + +@router.post("/search", responses=EXAMPLE_RESPONSES) +async def search( + pilot_db: PilotAgentsDB, + check_permissions: CheckPilotManagementPolicyCallable, + response: Response, + user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)], + page: int = 1, + per_page: int = 100, + body: Annotated[ + SearchParams | None, Body(openapi_examples=EXAMPLE_SEARCHES) # type: ignore + ] = None, +) -> list[dict[str, Any]]: + """Retrieve information about pilots.""" + # Inspired by /api/jobs/query + await check_permissions(action=ActionType.READ_PILOT_FIELDS) + + total, pilots = await search_bl( + pilot_db=pilot_db, + user_vo=user_info.vo, + page=page, + per_page=per_page, + body=body, + ) + + # Set the Content-Range header if needed + # https://datatracker.ietf.org/doc/html/rfc7233#section-4 + + # No pilots found but there are pilots for the requested search + # https://datatracker.ietf.org/doc/html/rfc7233#section-4.4 + if len(pilots) == 0 and total > 0: + response.headers["Content-Range"] = f"pilots */{total}" + response.status_code = HTTPStatus.REQUESTED_RANGE_NOT_SATISFIABLE + + # The total number of pilots is greater than the number of pilots returned + # https://datatracker.ietf.org/doc/html/rfc7233#section-4.2 + elif len(pilots) < total: + first_idx = per_page * (page - 1) + last_idx = min(first_idx + len(pilots), total) - 1 if total > 0 else 0 + response.headers["Content-Range"] = f"pilots {first_idx}-{last_idx}/{total}" + response.status_code = HTTPStatus.PARTIAL_CONTENT + return pilots + + +@router.post("/summary") +async def summary( + pilot_db: PilotAgentsDB, + user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)], + body: SummaryParams, + check_permissions: CheckPilotManagementPolicyCallable, +): + """Show information suitable for plotting.""" + await check_permissions(action=ActionType.READ_PILOT_FIELDS) + + return await summary_bl( + pilot_db=pilot_db, + body=body, + vo=user_info.vo, + ) diff --git a/diracx-routers/tests/pilots/test_pilot_creation.py b/diracx-routers/tests/pilots/test_pilot_creation.py new file mode 100644 index 000000000..d0cea485e --- /dev/null +++ b/diracx-routers/tests/pilots/test_pilot_creation.py @@ -0,0 +1,281 @@ +from __future__ import annotations + +from datetime import datetime, timezone + +import pytest +from sqlalchemy import update + +from diracx.core.models.pilot import PilotFieldsMapping, PilotStatus +from diracx.db.sql import PilotAgentsDB +from diracx.db.sql.pilots.schema import PilotAgents + +pytestmark = pytest.mark.enabled_dependencies( + [ + "PilotCredentialsAccessPolicy", + "DevelopmentSettings", + "AuthDB", + "AuthSettings", + "ConfigSource", + "BaseAccessPolicy", + "PilotAgentsDB", + "PilotManagementAccessPolicy", + "JobDB", + ] +) + +MAIN_VO = "lhcb" +N = 100 + + +@pytest.fixture +def normal_test_client(client_factory): + with client_factory.normal_user() as client: + yield client + + +async def test_create_pilots(normal_test_client): + # Lots of request, to validate that it returns the credentials in the same order as the input references + pilot_stamps = [f"stamps_{i}" for i in range(N)] + + # -------------- Bulk insert -------------- + body = {"pilot_stamps": pilot_stamps, "vo": MAIN_VO} + + r = normal_test_client.post( + "/api/pilots/", + json=body, + ) + + assert r.status_code == 200, r.json() + + # -------------- Register a pilot that already exists, and one that does not -------------- + + body = { + "pilot_stamps": [pilot_stamps[0], pilot_stamps[0] + "_new_one"], + "vo": MAIN_VO, + } + + r = normal_test_client.post( + "/api/pilots/", + json=body, + headers={ + "Content-Type": "application/json", + }, + ) + + assert r.status_code == 409 + assert ( + r.json()["detail"] + == f"The following pilots already exist: {{'{pilot_stamps[0]}'}}" + ) + + # -------------- Register a pilot that does not exists **but** was called before in an error -------------- + # To prove that, if I tried to register a pilot that does not exist with one that already exists, + # i can normally add the one that did not exist before (it should not have added it before) + body = {"pilot_stamps": [pilot_stamps[0] + "_new_one"], "vo": MAIN_VO} + + r = normal_test_client.post( + "/api/pilots/", + json=body, + headers={ + "Content-Type": "application/json", + }, + ) + + assert r.status_code == 200 + + +async def test_create_pilot_and_delete_it(normal_test_client): + pilot_stamp = "stamps_1" + + # -------------- Insert -------------- + body = {"pilot_stamps": [pilot_stamp], "vo": MAIN_VO} + + # Create a pilot + r = normal_test_client.post( + "/api/pilots/", + json=body, + ) + + assert r.status_code == 200, r.json() + + # -------------- Duplicate -------------- + # Duplicate because it exists, should have 409 + r = normal_test_client.post( + "/api/pilots/", + json=body, + ) + + assert r.status_code == 409, r.json() + + # -------------- Delete -------------- + params = {"pilot_stamps": [pilot_stamp]} + + # We delete the pilot + r = normal_test_client.delete( + "/api/pilots/", + params=params, + ) + + assert r.status_code == 204 + + # -------------- Insert -------------- + # Create a the same pilot, but works because it does not exist anymore + r = normal_test_client.post( + "/api/pilots/", + json=body, + ) + + assert r.status_code == 200, r.json() + + +async def test_create_pilot_and_modify_it(normal_test_client): + pilot_stamps = ["stamps_1", "stamp_2"] + + # -------------- Insert -------------- + body = {"pilot_stamps": pilot_stamps, "vo": MAIN_VO} + + # Create pilots + r = normal_test_client.post( + "/api/pilots/", + json=body, + ) + + assert r.status_code == 200, r.json() + + # -------------- Modify -------------- + # We modify only the first pilot + body = { + "pilot_stamps_to_fields_mapping": [ + PilotFieldsMapping( + PilotStamp=pilot_stamps[0], + BenchMark=1.0, + StatusReason="NewReason", + AccountingSent=True, + Status=PilotStatus.WAITING, + ).model_dump(exclude_unset=True) + ] + } + + r = normal_test_client.patch("/api/pilots/metadata", json=body) + + assert r.status_code == 204 + + body = { + "parameters": [], + "search": [], + "sort": [], + "distinct": True, + } + + r = normal_test_client.post("/api/pilots/search", json=body) + assert r.status_code == 200, r.json() + pilot1 = r.json()[0] + pilot2 = r.json()[1] + + assert pilot1["BenchMark"] == 1.0 + assert pilot1["StatusReason"] == "NewReason" + assert pilot1["AccountingSent"] + assert pilot1["Status"] == PilotStatus.WAITING + + assert pilot2["BenchMark"] != pilot1["BenchMark"] + assert pilot2["StatusReason"] != pilot1["StatusReason"] + assert pilot2["AccountingSent"] != pilot1["AccountingSent"] + assert pilot2["Status"] != pilot1["Status"] + + +@pytest.mark.asyncio +async def test_delete_pilots_by_age_and_stamp(normal_test_client): + # Generate 100 pilot stamps + pilot_stamps = [f"stamp_{i}" for i in range(100)] + + # -------------- Insert all pilots -------------- + body = {"pilot_stamps": pilot_stamps, "vo": MAIN_VO} + r = normal_test_client.post("/api/pilots/", json=body) + assert r.status_code == 200, r.json() + + # -------------- Modify last 50 pilots' fields -------------- + to_modify = pilot_stamps[50:] + mappings = [] + for idx, stamp in enumerate(to_modify): + # First 25 of modified set to ABORTED, others to WAITING + status = PilotStatus.ABORTED if idx < 25 else PilotStatus.WAITING + mapping = PilotFieldsMapping( + PilotStamp=stamp, + BenchMark=idx + 0.1, + StatusReason=f"Reason_{idx}", + AccountingSent=(idx % 2 == 0), + Status=status, + ).model_dump(exclude_unset=True) + mappings.append(mapping) + + r = normal_test_client.patch( + "/api/pilots/metadata", + json={"pilot_stamps_to_fields_mapping": mappings}, + ) + assert r.status_code == 204 + + # -------------- Directly set SubmissionTime to March 14, 2003 for last 50 -------------- + old_date = datetime(2003, 3, 14, tzinfo=timezone.utc) + # Access DB session from normal_test_client fixtures + db = normal_test_client.app.dependency_overrides[PilotAgentsDB.transaction].args[0] + + async with db: + stmt = ( + update(PilotAgents) + .where(PilotAgents.pilot_stamp.in_(to_modify)) + .values(SubmissionTime=old_date) + ) + await db.conn.execute(stmt) + await db.conn.commit() + + # -------------- Verify all 100 pilots exist -------------- + search_body = {"parameters": [], "search": [], "sort": [], "distinct": True} + r = normal_test_client.post("/api/pilots/search", json=search_body) + assert r.status_code == 200, r.json() + assert len(r.json()) == 100 + + # -------------- 1) Delete only old aborted pilots (25 expected) -------------- + # age_in_days large enough to include 2003-03-14 + r = normal_test_client.delete( + "/api/pilots/", + params={"age_in_days": 15, "delete_only_aborted": True}, + ) + assert r.status_code == 204 + # Expect 75 remaining + r = normal_test_client.post("/api/pilots/search", json=search_body) + assert len(r.json()) == 75 + + # -------------- 2) Delete all old pilots (remaining 25 old) -------------- + r = normal_test_client.delete( + "/api/pilots/", + params={"age_in_days": 15}, + ) + assert r.status_code == 204 + + # Expect 50 remaining + r = normal_test_client.post("/api/pilots/search", json=search_body) + assert len(r.json()) == 50 + + # -------------- 3) Delete one recent pilot by stamp -------------- + one_stamp = pilot_stamps[10] + r = normal_test_client.delete("/api/pilots/", params={"pilot_stamps": [one_stamp]}) + assert r.status_code == 204 + # Expect 49 remaining + r = normal_test_client.post("/api/pilots/search", json=search_body) + assert len(r.json()) == 49 + + # -------------- 4) Delete all remaining pilots -------------- + # Collect remaining stamps + remaining = [p["PilotStamp"] for p in r.json()] + r = normal_test_client.delete("/api/pilots/", params={"pilot_stamps": remaining}) + assert r.status_code == 204 + # Expect none remaining + r = normal_test_client.post("/api/pilots/search", json=search_body) + assert r.status_code == 200 + assert len(r.json()) == 0 + + # -------------- 5) Attempt deleting unknown pilot, expect 400 -------------- + r = normal_test_client.delete( + "/api/pilots/", params={"pilot_stamps": ["unknown_stamp"]} + ) + assert r.status_code == 204 diff --git a/diracx-routers/tests/pilots/test_query.py b/diracx-routers/tests/pilots/test_query.py new file mode 100644 index 000000000..eeea8423e --- /dev/null +++ b/diracx-routers/tests/pilots/test_query.py @@ -0,0 +1,406 @@ +"""Inspired by pilots and jobs db search tests.""" + +from __future__ import annotations + +import pytest +from fastapi.testclient import TestClient + +from diracx.core.exceptions import InvalidQueryError +from diracx.core.models.pilot import PilotFieldsMapping, PilotStatus +from diracx.core.models.search import ScalarSearchOperator, ScalarSearchSpec, SortDirection, SortSpec, VectorSearchOperator, VectorSearchSpec + +pytestmark = pytest.mark.enabled_dependencies( + [ + "AuthSettings", + "ConfigSource", + "DevelopmentSettings", + "PilotAgentsDB", + "PilotManagementAccessPolicy", + ] +) + + +@pytest.fixture +def normal_test_client(client_factory): + with client_factory.normal_user() as client: + yield client + + +MAIN_VO = "lhcb" +N = 100 + +PILOT_REASONS = [ + "I was sick", + "I can't, I have a pony.", + "I was shopping", + "I was sleeping", +] + +PILOT_STATUSES = list(PilotStatus) + + +@pytest.fixture +async def populated_pilot_client(normal_test_client): + pilot_stamps = [f"stamp_{i}" for i in range(1, N + 1)] + + # -------------- Bulk insert -------------- + body = {"vo": MAIN_VO, "pilot_stamps": pilot_stamps} + + r = normal_test_client.post( + "/api/pilots/", + json=body, + ) + + assert r.status_code == 200, r.json() + + body = { + "pilot_stamps_to_fields_mapping": [ + PilotFieldsMapping( + PilotStamp=pilot_stamp, + BenchMark=i**2, + StatusReason=PILOT_REASONS[i % len(PILOT_REASONS)], + AccountingSent=True, + Status=PILOT_STATUSES[i % len(PILOT_STATUSES)], + CurrentJobID=i, + Queue=f"queue_{i}", + ).model_dump(exclude_unset=True) + for i, pilot_stamp in enumerate(pilot_stamps) + ] + } + + r = normal_test_client.patch("/api/pilots/metadata", json=body) + + assert r.status_code == 204 + + yield normal_test_client + + +async def test_pilot_summary(populated_pilot_client: TestClient): + # Group by StatusReason + r = populated_pilot_client.post( + "/api/pilots/summary", + json={ + "grouping": ["StatusReason"], + }, + ) + + assert r.status_code == 200 + + assert sum([el["count"] for el in r.json()]) == N + assert len(r.json()) == len(PILOT_REASONS) + + # Group by CurrentJobID + r = populated_pilot_client.post( + "/api/pilots/summary", + json={ + "grouping": ["CurrentJobID"], + }, + ) + + assert r.status_code == 200 + + assert all(el["count"] == 1 for el in r.json()) + assert len(r.json()) == N + + # Group by CurrentJobID where BenchMark < 10^2 + r = populated_pilot_client.post( + "/api/pilots/summary", + json={ + "grouping": ["CurrentJobID"], + "search": [{"parameter": "BenchMark", "operator": "lt", "value": 10**2}], + }, + ) + + assert r.status_code == 200, r.json() + + assert all(el["count"] == 1 for el in r.json()) + assert len(r.json()) == 10 + + +@pytest.fixture +async def search(populated_pilot_client): + async def _search( + parameters, conditions, sorts, distinct=False, page=1, per_page=100 + ): + body = { + "parameters": parameters, + "search": conditions, + "sort": sorts, + "distinct": distinct, + } + + params = {"per_page": per_page, "page": page} + + r = populated_pilot_client.post("/api/pilots/search", json=body, params=params) + + if r.status_code == 400: + # If we have a status_code 400, that means that the query failed + raise InvalidQueryError() + + return r.json(), r.headers + + return _search + + +async def test_search_parameters(search): + """Test that we can search specific parameters for pilots.""" + # Search a specific parameter: PilotID + result, headers = await search(["PilotID"], [], []) + assert len(result) == N + assert result + for r in result: + assert r.keys() == {"PilotID"} + assert "Content-Range" not in headers + + # Search a specific parameter: Status + result, headers = await search(["Status"], [], []) + assert len(result) == N + assert result + for r in result: + assert r.keys() == {"Status"} + assert "Content-Range" not in headers + + # Search for multiple parameters: PilotID, Status + result, headers = await search(["PilotID", "Status"], [], []) + assert len(result) == N + assert result + for r in result: + assert r.keys() == {"PilotID", "Status"} + assert "Content-Range" not in headers + + # Search for a specific parameter but use distinct: Status + result, headers = await search(["Status"], [], [], distinct=True) + assert len(result) == len(PILOT_STATUSES) + assert result + assert "Content-Range" not in headers + + # Search for a non-existent parameter: Dummy + with pytest.raises(InvalidQueryError): + result, headers = await search(["Dummy"], [], []) + + +async def test_search_conditions(search): + """Test that we can search for specific pilots.""" + # Search a specific scalar condition: PilotID eq 3 + condition = ScalarSearchSpec( + parameter="PilotID", operator=ScalarSearchOperator.EQUAL, value=3 + ) + result, headers = await search([], [condition], []) + assert len(result) == 1 + assert result + assert len(result) == 1 + assert result[0]["PilotID"] == 3 + assert "Content-Range" not in headers + + # Search a specific scalar condition: PilotID lt 3 + condition = ScalarSearchSpec( + parameter="PilotID", operator=ScalarSearchOperator.LESS_THAN, value=3 + ) + result, headers = await search([], [condition], []) + assert len(result) == 2 + assert result + assert len(result) == 2 + assert result[0]["PilotID"] == 1 + assert result[1]["PilotID"] == 2 + assert "Content-Range" not in headers + + # Search a specific scalar condition: PilotID neq 3 + condition = ScalarSearchSpec( + parameter="PilotID", operator=ScalarSearchOperator.NOT_EQUAL, value=3 + ) + result, headers = await search([], [condition], []) + assert len(result) == 99 + assert result + assert len(result) == 99 + assert all(r["PilotID"] != 3 for r in result) + assert "Content-Range" not in headers + + # Search a specific scalar condition: PilotID eq 5873 (does not exist) + condition = ScalarSearchSpec( + parameter="PilotID", operator=ScalarSearchOperator.EQUAL, value=5873 + ) + result, headers = await search([], [condition], []) + assert not result + assert "Content-Range" not in headers + + # Search a specific vector condition: PilotID in 1,2,3 + condition = VectorSearchSpec( + parameter="PilotID", operator=VectorSearchOperator.IN, values=[1, 2, 3] + ) + result, headers = await search([], [condition], []) + assert len(result) == 3 + assert result + assert len(result) == 3 + assert all(r["PilotID"] in [1, 2, 3] for r in result) + assert "Content-Range" not in headers + + # Search a specific vector condition: PilotID in 1,2,5873 (one of them does not exist) + condition = VectorSearchSpec( + parameter="PilotID", operator=VectorSearchOperator.IN, values=[1, 2, 5873] + ) + result, headers = await search([], [condition], []) + assert len(result) == 2 + assert result + assert len(result) == 2 + assert all(r["PilotID"] in [1, 2] for r in result) + assert "Content-Range" not in headers + + # Search a specific vector condition: PilotID not in 1,2,3 + condition = VectorSearchSpec( + parameter="PilotID", operator=VectorSearchOperator.NOT_IN, values=[1, 2, 3] + ) + result, headers = await search([], [condition], []) + assert len(result) == 97 + assert result + assert len(result) == 97 + assert all(r["PilotID"] not in [1, 2, 3] for r in result) + assert "Content-Range" not in headers + + # Search a specific vector condition: PilotID not in 1,2,5873 (one of them does not exist) + condition = VectorSearchSpec( + parameter="PilotID", + operator=VectorSearchOperator.NOT_IN, + values=[1, 2, 5873], + ) + result, headers = await search([], [condition], []) + assert len(result) == 98 + assert result + assert len(result) == 98 + assert all(r["PilotID"] not in [1, 2] for r in result) + assert "Content-Range" not in headers + + # Search for multiple conditions based on different parameters: PilotID eq 70, PilotID in 4,5,6 + condition1 = ScalarSearchSpec( + parameter="PilotStamp", operator=ScalarSearchOperator.EQUAL, value="stamp_5" + ) + condition2 = VectorSearchSpec( + parameter="PilotID", operator=VectorSearchOperator.IN, values=[4, 5, 6] + ) + result, headers = await search([], [condition1, condition2], []) + + assert result + assert len(result) == 1 + assert result[0]["PilotID"] == 5 + assert result[0]["PilotStamp"] == "stamp_5" + assert "Content-Range" not in headers + + # Search for multiple conditions based on the same parameter: PilotID eq 70, PilotID in 4,5,6 + condition1 = ScalarSearchSpec( + parameter="PilotID", operator=ScalarSearchOperator.EQUAL, value=70 + ) + condition2 = VectorSearchSpec( + parameter="PilotID", operator=VectorSearchOperator.IN, values=[4, 5, 6] + ) + result, headers = await search([], [condition1, condition2], []) + assert len(result) == 0 + assert not result + assert "Content-Range" not in headers + + +async def test_search_sorts(search): + """Test that we can search for pilots and sort the results.""" + # Search and sort by PilotID in ascending order + sort = SortSpec(parameter="PilotID", direction=SortDirection.ASC) + result, headers = await search([], [], [sort]) + assert len(result) == N + assert result + for i, r in enumerate(result): + assert r["PilotID"] == i + 1 + assert "Content-Range" not in headers + + # Search and sort by PilotID in descending order + sort = SortSpec(parameter="PilotID", direction=SortDirection.DESC) + result, headers = await search([], [], [sort]) + assert len(result) == N + assert result + for i, r in enumerate(result): + assert r["PilotID"] == N - i + assert "Content-Range" not in headers + + # Search and sort by PilotStamp in ascending order + sort = SortSpec(parameter="PilotStamp", direction=SortDirection.ASC) + result, headers = await search([], [], [sort]) + assert len(result) == N + assert result + # Assert that stamp_10 is before stamp_2 because of the lexicographical order + assert result[2]["PilotStamp"] == "stamp_100" + assert result[12]["PilotStamp"] == "stamp_2" + assert "Content-Range" not in headers + + # Search and sort by PilotStamp in descending order + sort = SortSpec(parameter="PilotStamp", direction=SortDirection.DESC) + result, headers = await search([], [], [sort]) + assert len(result) == N + assert result + # Assert that stamp_10 is before stamp_2 because of the lexicographical order + assert result[97]["PilotStamp"] == "stamp_100" + assert result[87]["PilotStamp"] == "stamp_2" + assert "Content-Range" not in headers + + # Search and sort by PilotStamp in ascending order and PilotID in descending order + sort1 = SortSpec(parameter="PilotStamp", direction=SortDirection.ASC) + sort2 = SortSpec(parameter="PilotID", direction=SortDirection.DESC) + result, headers = await search([], [], [sort1, sort2]) + assert len(result) == N + assert result + assert result[0]["PilotStamp"] == "stamp_1" + assert result[0]["PilotID"] == 1 + assert result[99]["PilotStamp"] == "stamp_99" + assert result[99]["PilotID"] == 99 + assert "Content-Range" not in headers + + +async def test_search_pagination(search): + """Test that we can search for pilots.""" + # Search for the first 10 pilots + result, headers = await search([], [], [], per_page=10, page=1) + assert "Content-Range" in headers + # Because Content-Range = f"pilots {first_idx}-{last_idx}/{total}" + total = int(headers["Content-Range"].split("/")[1]) + assert total == N + assert result + assert len(result) == 10 + assert result[0]["PilotID"] == 1 + + # Search for the second 10 pilots + result, headers = await search([], [], [], per_page=10, page=2) + assert "Content-Range" in headers + # Because Content-Range = f"pilots {first_idx}-{last_idx}/{total}" + total = int(headers["Content-Range"].split("/")[1]) + assert total == N + assert result + assert len(result) == 10 + assert result[0]["PilotID"] == 11 + + # Search for the last 10 pilots + result, headers = await search([], [], [], per_page=10, page=10) + assert "Content-Range" in headers + # Because Content-Range = f"pilots {first_idx}-{last_idx}/{total}" + total = int(headers["Content-Range"].split("/")[1]) + assert result + assert len(result) == 10 + assert result[0]["PilotID"] == 91 + + # Search for the second 50 pilots + result, headers = await search([], [], [], per_page=50, page=2) + assert "Content-Range" in headers + # Because Content-Range = f"pilots {first_idx}-{last_idx}/{total}" + total = int(headers["Content-Range"].split("/")[1]) + assert result + assert len(result) == 50 + assert result[0]["PilotID"] == 51 + + # Invalid page number + result, headers = await search([], [], [], per_page=10, page=11) + assert "Content-Range" in headers + # Because Content-Range = f"pilots {first_idx}-{last_idx}/{total}" + total = int(headers["Content-Range"].split("/")[1]) + assert not result + + # Invalid page number + with pytest.raises(InvalidQueryError): + result = await search([], [], [], per_page=10, page=0) + + # Invalid per_page number + with pytest.raises(InvalidQueryError): + result = await search([], [], [], per_page=0, page=1) diff --git a/docs/dev/explanations/pilots.md b/docs/dev/explanations/pilots.md new file mode 100644 index 000000000..544e12db5 --- /dev/null +++ b/docs/dev/explanations/pilots.md @@ -0,0 +1,20 @@ +## Presentation + +Pilots are a piece of software that is running on *worker nodes*. There are two types of pilots: "DIRAC pilots", and "DiracX pilots". The first type corresponds to pilots with proxies, sent by DIRAC; and the second type corresponds to pilots with secrets. Both kinds will eventually interact with DiracX using tokens (DIRAC pilots by exchanging their proxies for tokens, DiracX by exchanging their secrets for tokens). + +## Management + +Their management is adapted in DiracX, and each feature has its own route in DiracX. We will split the `/pilots` route into two parts: + +1. `/api/pilots/*` to allow administrators and users to access and modify pilots +2. `/api/pilots/internal/*` is allocated for pilots resources: only DiracX pilots will have access to these resources + +Each part has its own security policy: we want to prevent pilots to access users resources and vice-versa. To differentiate DIRAC pilots from users, we can get their token and compare their properties: `GENERIC_PILOT` is the property that defines a pilot. For DiracX pilots, we can differentiate them by looking at the token structure: they don't have properties, but a "stamp" (their identifier). + +## Endpoints + +We ordered our endpoints like so: + +1. Creation: `POST /api/pilots/` +2. Deletion: `DELETE /api/pilots/` +3. Modification: `PATCH /api/pilots/metadata` diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/__init__.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/__init__.py index b97d2e439..25803aefc 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/__init__.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/__init__.py @@ -16,6 +16,8 @@ BodyAuthRevokeRefreshTokenByRefreshToken, BodyJobsRescheduleJobs, BodyJobsUnassignBulkJobsSandboxes, + BodyPilotsAddPilotStamps, + BodyPilotsUpdatePilotFields, ExtendedMetadata, GroupInfo, HTTPValidationError, @@ -26,6 +28,7 @@ JobMetaData, JobStatusUpdate, OpenIDConfiguration, + PilotFieldsMapping, SandboxDownloadResponse, SandboxInfo, SandboxUploadResponse, @@ -48,6 +51,7 @@ from ._enums import ( # type: ignore ChecksumAlgorithm, JobStatus, + PilotStatus, SandboxFormat, SandboxType, ScalarSearchOperator, @@ -63,6 +67,8 @@ "BodyAuthRevokeRefreshTokenByRefreshToken", "BodyJobsRescheduleJobs", "BodyJobsUnassignBulkJobsSandboxes", + "BodyPilotsAddPilotStamps", + "BodyPilotsUpdatePilotFields", "ExtendedMetadata", "GroupInfo", "HTTPValidationError", @@ -73,6 +79,7 @@ "JobMetaData", "JobStatusUpdate", "OpenIDConfiguration", + "PilotFieldsMapping", "SandboxDownloadResponse", "SandboxInfo", "SandboxUploadResponse", @@ -92,6 +99,7 @@ "VectorSearchSpec", "ChecksumAlgorithm", "JobStatus", + "PilotStatus", "SandboxFormat", "SandboxType", "ScalarSearchOperator", diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/_enums.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/_enums.py index b83473639..849d3252a 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/_enums.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/_enums.py @@ -34,6 +34,19 @@ class JobStatus(str, Enum, metaclass=CaseInsensitiveEnumMeta): RESCHEDULED = "Rescheduled" +class PilotStatus(str, Enum, metaclass=CaseInsensitiveEnumMeta): + """PilotStatus.""" + + SUBMITTED = "Submitted" + WAITING = "Waiting" + RUNNING = "Running" + DONE = "Done" + FAILED = "Failed" + DELETED = "Deleted" + ABORTED = "Aborted" + UNKNOWN = "Unknown" + + class SandboxFormat(str, Enum, metaclass=CaseInsensitiveEnumMeta): """SandboxFormat.""" diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/_models.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/_models.py index 69b8ffcf1..946623cbb 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/_models.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/_models.py @@ -184,6 +184,109 @@ def __init__(self, *, job_ids: list[int], **kwargs: Any) -> None: self.job_ids = job_ids +class BodyPilotsAddPilotStamps(_serialization.Model): + """Body_pilots_add_pilot_stamps. + + All required parameters must be populated in order to send to server. + + :ivar pilot_stamps: List of the pilot stamps we want to add to the db. Required. + :vartype pilot_stamps: list[str] + :ivar vo: Pilot virtual organization. Required. + :vartype vo: str + :ivar grid_type: Grid type of the pilots. + :vartype grid_type: str + :ivar grid_site: Pilots grid site. + :vartype grid_site: str + :ivar destination_site: Pilots destination site. + :vartype destination_site: str + :ivar pilot_references: Association of a pilot reference with a pilot stamp. + :vartype pilot_references: dict[str, str] + :ivar pilot_status: Status of the pilots. Known values are: "Submitted", "Waiting", "Running", + "Done", "Failed", "Deleted", "Aborted", and "Unknown". + :vartype pilot_status: str or ~_generated.models.PilotStatus + """ + + _validation = { + "pilot_stamps": {"required": True}, + "vo": {"required": True}, + } + + _attribute_map = { + "pilot_stamps": {"key": "pilot_stamps", "type": "[str]"}, + "vo": {"key": "vo", "type": "str"}, + "grid_type": {"key": "grid_type", "type": "str"}, + "grid_site": {"key": "grid_site", "type": "str"}, + "destination_site": {"key": "destination_site", "type": "str"}, + "pilot_references": {"key": "pilot_references", "type": "{str}"}, + "pilot_status": {"key": "pilot_status", "type": "str"}, + } + + def __init__( + self, + *, + pilot_stamps: List[str], + vo: str, + grid_type: str = "Dirac", + grid_site: str = "Unknown", + destination_site: str = "NotAssigned", + pilot_references: Optional[Dict[str, str]] = None, + pilot_status: Optional[Union[str, "_models.PilotStatus"]] = None, + **kwargs: Any + ) -> None: + """ + :keyword pilot_stamps: List of the pilot stamps we want to add to the db. Required. + :paramtype pilot_stamps: list[str] + :keyword vo: Pilot virtual organization. Required. + :paramtype vo: str + :keyword grid_type: Grid type of the pilots. + :paramtype grid_type: str + :keyword grid_site: Pilots grid site. + :paramtype grid_site: str + :keyword destination_site: Pilots destination site. + :paramtype destination_site: str + :keyword pilot_references: Association of a pilot reference with a pilot stamp. + :paramtype pilot_references: dict[str, str] + :keyword pilot_status: Status of the pilots. Known values are: "Submitted", "Waiting", + "Running", "Done", "Failed", "Deleted", "Aborted", and "Unknown". + :paramtype pilot_status: str or ~_generated.models.PilotStatus + """ + super().__init__(**kwargs) + self.pilot_stamps = pilot_stamps + self.vo = vo + self.grid_type = grid_type + self.grid_site = grid_site + self.destination_site = destination_site + self.pilot_references = pilot_references + self.pilot_status = pilot_status + + +class BodyPilotsUpdatePilotFields(_serialization.Model): + """Body_pilots_update_pilot_fields. + + All required parameters must be populated in order to send to server. + + :ivar pilot_stamps_to_fields_mapping: (pilot_stamp, pilot_fields) mapping to change. Required. + :vartype pilot_stamps_to_fields_mapping: list[~_generated.models.PilotFieldsMapping] + """ + + _validation = { + "pilot_stamps_to_fields_mapping": {"required": True}, + } + + _attribute_map = { + "pilot_stamps_to_fields_mapping": {"key": "pilot_stamps_to_fields_mapping", "type": "[PilotFieldsMapping]"}, + } + + def __init__(self, *, pilot_stamps_to_fields_mapping: List["_models.PilotFieldsMapping"], **kwargs: Any) -> None: + """ + :keyword pilot_stamps_to_fields_mapping: (pilot_stamp, pilot_fields) mapping to change. + Required. + :paramtype pilot_stamps_to_fields_mapping: list[~_generated.models.PilotFieldsMapping] + """ + super().__init__(**kwargs) + self.pilot_stamps_to_fields_mapping = pilot_stamps_to_fields_mapping + + class ExtendedMetadata(_serialization.Model): """ExtendedMetadata. @@ -950,6 +1053,102 @@ def __init__( self.code_challenge_methods_supported = code_challenge_methods_supported +class PilotFieldsMapping(_serialization.Model): + """All the fields that a user can modify on a Pilot (except PilotStamp). + + All required parameters must be populated in order to send to server. + + :ivar pilot_stamp: Pilotstamp. Required. + :vartype pilot_stamp: str + :ivar status_reason: Statusreason. + :vartype status_reason: str + :ivar status: PilotStatus. Known values are: "Submitted", "Waiting", "Running", "Done", + "Failed", "Deleted", "Aborted", and "Unknown". + :vartype status: str or ~_generated.models.PilotStatus + :ivar bench_mark: Benchmark. + :vartype bench_mark: float + :ivar destination_site: Destinationsite. + :vartype destination_site: str + :ivar queue: Queue. + :vartype queue: str + :ivar grid_site: Gridsite. + :vartype grid_site: str + :ivar grid_type: Gridtype. + :vartype grid_type: str + :ivar accounting_sent: Accountingsent. + :vartype accounting_sent: bool + :ivar current_job_id: Currentjobid. + :vartype current_job_id: int + """ + + _validation = { + "pilot_stamp": {"required": True}, + } + + _attribute_map = { + "pilot_stamp": {"key": "PilotStamp", "type": "str"}, + "status_reason": {"key": "StatusReason", "type": "str"}, + "status": {"key": "Status", "type": "str"}, + "bench_mark": {"key": "BenchMark", "type": "float"}, + "destination_site": {"key": "DestinationSite", "type": "str"}, + "queue": {"key": "Queue", "type": "str"}, + "grid_site": {"key": "GridSite", "type": "str"}, + "grid_type": {"key": "GridType", "type": "str"}, + "accounting_sent": {"key": "AccountingSent", "type": "bool"}, + "current_job_id": {"key": "CurrentJobID", "type": "int"}, + } + + def __init__( + self, + *, + pilot_stamp: str, + status_reason: Optional[str] = None, + status: Optional[Union[str, "_models.PilotStatus"]] = None, + bench_mark: Optional[float] = None, + destination_site: Optional[str] = None, + queue: Optional[str] = None, + grid_site: Optional[str] = None, + grid_type: Optional[str] = None, + accounting_sent: Optional[bool] = None, + current_job_id: Optional[int] = None, + **kwargs: Any + ) -> None: + """ + :keyword pilot_stamp: Pilotstamp. Required. + :paramtype pilot_stamp: str + :keyword status_reason: Statusreason. + :paramtype status_reason: str + :keyword status: PilotStatus. Known values are: "Submitted", "Waiting", "Running", "Done", + "Failed", "Deleted", "Aborted", and "Unknown". + :paramtype status: str or ~_generated.models.PilotStatus + :keyword bench_mark: Benchmark. + :paramtype bench_mark: float + :keyword destination_site: Destinationsite. + :paramtype destination_site: str + :keyword queue: Queue. + :paramtype queue: str + :keyword grid_site: Gridsite. + :paramtype grid_site: str + :keyword grid_type: Gridtype. + :paramtype grid_type: str + :keyword accounting_sent: Accountingsent. + :paramtype accounting_sent: bool + :keyword current_job_id: Currentjobid. + :paramtype current_job_id: int + """ + super().__init__(**kwargs) + self.pilot_stamp = pilot_stamp + self.status_reason = status_reason + self.status = status + self.bench_mark = bench_mark + self.destination_site = destination_site + self.queue = queue + self.grid_site = grid_site + self.grid_type = grid_type + self.accounting_sent = accounting_sent + self.current_job_id = current_job_id + + class SandboxDownloadResponse(_serialization.Model): """SandboxDownloadResponse. diff --git a/pixi.lock b/pixi.lock index 8ae9e4d6b..c2ff7efa8 100644 --- a/pixi.lock +++ b/pixi.lock @@ -19241,8 +19241,8 @@ packages: requires_python: '>=3.11' - pypi: ./ name: diracx - version: 0.0.13.dev11+gbffb6c6d9.d20260414 - sha256: 410903a3be93f06d98b9df3cd204f3a92c585c3f4424516134f6b7272630536b + version: 0.0.13.dev10+g09d7149dd.d20260414 + sha256: 1f78b10647ef5e2e13a5438ef3c7f2ac2c051773bf180e30f00fdc45a328425f requires_dist: - diracx-api - diracx-cli @@ -19252,7 +19252,7 @@ packages: requires_python: '>=3.11' - pypi: ./diracx-api name: diracx-api - version: 0.0.13.dev10+ge74cb7c97.d20260408 + version: 0.0.13.dev10+g09d7149dd.d20260414 sha256: fce056f16b4ca37c0b2847bf95cf7cf02d2f75b1bc63793efd3fc959dfbc0cb9 requires_dist: - diracx-client @@ -19263,7 +19263,7 @@ packages: requires_python: '>=3.11' - pypi: ./diracx-cli name: diracx-cli - version: 0.0.13.dev10+ge74cb7c97.d20260408 + version: 0.0.13.dev10+g09d7149dd.d20260414 sha256: a9c02d48d01723886e3f95b1379cf844587bbca1cc354e629dc221266fbeef8c requires_dist: - diraccfg @@ -19280,7 +19280,7 @@ packages: requires_python: '>=3.11' - pypi: ./diracx-client name: diracx-client - version: 0.0.13.dev10+ge74cb7c97.d20260408 + version: 0.0.13.dev10+g09d7149dd.d20260414 sha256: 3d974bce5bd5a086bb1e8e6263dbab10927dfc446d4ff44836433a18508bd727 requires_dist: - azure-core @@ -19292,7 +19292,7 @@ packages: requires_python: '>=3.11' - pypi: ./diracx-core name: diracx-core - version: 0.0.13.dev10+ge74cb7c97.d20260408 + version: 0.0.13.dev10+g09d7149dd.d20260414 sha256: febb534a1a976612961f00cc690f64ff4f4f6ecef7a649fb1413602b3bb9f6fd requires_dist: - aiobotocore>=2.15 @@ -19331,7 +19331,7 @@ packages: requires_python: '>=3.11' - pypi: ./diracx-db name: diracx-db - version: 0.0.13.dev10+ge74cb7c97.d20260408 + version: 0.0.13.dev10+g09d7149dd.d20260414 sha256: 55e3f8e41ada508c6d544766d0c8f2dfcc6d798a127e176e2050a4aa5cb8d228 requires_dist: - diracx-core @@ -19362,7 +19362,7 @@ packages: requires_python: '>=3.11' - pypi: ./diracx-logic name: diracx-logic - version: 0.0.13.dev10+ge74cb7c97.d20260408 + version: 0.0.13.dev10+g09d7149dd.d20260414 sha256: e7e3d6391b7d5b6d4ce909ef500ac152b149c7c48f80802320a0baead3e9ba3e requires_dist: - cachetools @@ -19378,8 +19378,8 @@ packages: requires_python: '>=3.11' - pypi: ./diracx-routers name: diracx-routers - version: 0.0.13.dev10+ge74cb7c97.d20260408 - sha256: 0077513713e84925ecc3c53259ecf985514b40dfe1d7688eace6e048bfe3b727 + version: 0.0.13.dev10+g09d7149dd.d20260414 + sha256: ef0c49134e20b3a5232131ec53931179466d3f8db3a16b846660ff39a4978acc requires_dist: - cachetools - diracx-core @@ -19409,8 +19409,8 @@ packages: requires_python: '>=3.11' - pypi: ./diracx-tasks name: diracx-tasks - version: 0.0.13.dev10+ge74cb7c97.d20260408 - sha256: 47ecbf1d4db5442abf0cb47de03b0c0ce7064cfb0f9ff9bf009c9a2041d02db1 + version: 0.0.13.dev10+g09d7149dd.d20260414 + sha256: 752189cc698d17c76e8c240a4a82f6593a5dc040545c29a504753de135b1bd6b requires_dist: - croniter - diracx-core @@ -19427,7 +19427,7 @@ packages: requires_python: '>=3.11' - pypi: ./diracx-testing name: diracx-testing - version: 0.0.13.dev10+ge74cb7c97.d20260408 + version: 0.0.13.dev10+g09d7149dd.d20260414 sha256: 9e7f8dc219ef9a81e9d5b14c0140344ee1ddabc65222a18f17faf5b5becce456 requires_dist: - httpx @@ -20091,14 +20091,14 @@ packages: timestamp: 1748320218212 - pypi: ./extensions/gubbins name: gubbins - version: 0.0.13.dev10+ge74cb7c97.d20260408 + version: 0.0.13.dev10+g09d7149dd.d20260414 sha256: 005a02b3df8d030f0ff43a321b2a5b7c177ecca059d331a8b93c77a100ad0ceb requires_dist: - gubbins-testing ; extra == 'testing' requires_python: '>=3.11' - pypi: ./extensions/gubbins/gubbins-api name: gubbins-api - version: 0.0.11.dev25+g109843191.d20260324 + version: 0.0.12.dev4+ge552651ac sha256: 31031bdd61bfe53d391e1650c0cba042fb300b6143f7aae66b5848f4ee3276af requires_dist: - diracx-api @@ -20107,7 +20107,7 @@ packages: requires_python: '>=3.11' - pypi: ./extensions/gubbins/gubbins-cli name: gubbins-cli - version: 0.0.11.dev25+g109843191.d20260324 + version: 0.0.12.dev4+ge552651ac sha256: fbb038cf5c271ae96149106087c18c6ce27e5578d565d3de6d376d20707ffc8f requires_dist: - diracx-cli @@ -20117,7 +20117,7 @@ packages: requires_python: '>=3.11' - pypi: ./extensions/gubbins/gubbins-client name: gubbins-client - version: 0.0.11.dev25+g109843191.d20260324 + version: 0.0.12.dev4+ge552651ac sha256: 7873beaff1c5895c83282bf16de842c97cc13c5033a3f2e039a84ff0aa9e7cec requires_dist: - diracx-client @@ -20126,7 +20126,7 @@ packages: requires_python: '>=3.11' - pypi: ./extensions/gubbins/gubbins-core name: gubbins-core - version: 0.0.11.dev25+g109843191.d20260324 + version: 0.0.12.dev4+ge552651ac sha256: 27daa6103085f4e4438b5ece4b3422b224ffb9fe4b39c93aae7a5c4ae40884ec requires_dist: - diracx-core @@ -20137,7 +20137,7 @@ packages: requires_python: '>=3.11' - pypi: ./extensions/gubbins/gubbins-db name: gubbins-db - version: 0.0.13.dev10+ge74cb7c97.d20260408 + version: 0.0.13.dev10+g09d7149dd.d20260414 sha256: 423f2a5336b71eee661db826234b2a1e22e0920da3cd66ab795372caef174d05 requires_dist: - diracx-db @@ -20146,7 +20146,7 @@ packages: requires_python: '>=3.11' - pypi: ./extensions/gubbins/gubbins-logic name: gubbins-logic - version: 0.0.13.dev10+ge74cb7c97.d20260408 + version: 0.0.13.dev10+g09d7149dd.d20260414 sha256: ce4627c8c026fdbdbfaa11a8b1ff1cd4e206abb53983d2e13e2459ca8c5a9b69 requires_dist: - diracx-logic @@ -20160,7 +20160,7 @@ packages: requires_python: '>=3.11' - pypi: ./extensions/gubbins/gubbins-routers name: gubbins-routers - version: 0.0.13.dev10+ge74cb7c97.d20260408 + version: 0.0.13.dev10+g09d7149dd.d20260414 sha256: ad9b00b4ea222fe7b5b2913455e9d817b8e72d8ce2ecd240f2732f4992c29951 requires_dist: - diracx-routers @@ -20177,7 +20177,7 @@ packages: requires_python: '>=3.11' - pypi: ./extensions/gubbins/gubbins-tasks name: gubbins-tasks - version: 0.0.13.dev10+ge74cb7c97.d20260408 + version: 0.0.13.dev10+g09d7149dd.d20260414 sha256: c2b53c4c625ffc5745b191eb68a1df61a7fe89cfb2a12a31f17dc8985543c1e1 requires_dist: - diracx-tasks @@ -20189,7 +20189,7 @@ packages: requires_python: '>=3.11' - pypi: ./extensions/gubbins/gubbins-testing name: gubbins-testing - version: 0.0.11.dev25+g109843191.d20260324 + version: 0.0.12.dev4+ge552651ac sha256: 2247538bbb010522cc9675f0146807d553c7fafe71a219795ea5bb2a46dcf2e0 requires_dist: - diracx-testing From 257e486746f8cf688b7b4a61d6ff3cbd787c8909 Mon Sep 17 00:00:00 2001 From: Robin VAN DE MERGHEL Date: Tue, 5 Aug 2025 10:04:48 +0200 Subject: [PATCH 2/9] fix: Add more security to the pilot creation router --- .../src/diracx/routers/pilots/management.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/diracx-routers/src/diracx/routers/pilots/management.py b/diracx-routers/src/diracx/routers/pilots/management.py index 48c604b96..7f892335a 100644 --- a/diracx-routers/src/diracx/routers/pilots/management.py +++ b/diracx-routers/src/diracx/routers/pilots/management.py @@ -9,7 +9,7 @@ PilotAlreadyExistsError, ) from diracx.core.models.pilot import PilotFieldsMapping, PilotStatus -from diracx.core.properties import GENERIC_PILOT +from diracx.core.properties import GENERIC_PILOT, JOB_ADMINISTRATOR from diracx.db.sql import JobDB, PilotAgentsDB from diracx.logic.pilots.management import ( delete_pilots as delete_pilots_bl, @@ -69,10 +69,17 @@ async def add_pilot_stamps( if GENERIC_PILOT in user_info.properties: if len(pilot_stamps) != 1: raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, + status_code=status.HTTP_403_FORBIDDEN, detail="As a pilot, you can only create yourself.", ) + if JOB_ADMINISTRATOR not in user_info.properties: + if not vo == user_info.vo: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="You can create pilots only for your VO.", + ) + try: await register_new_pilots( pilot_db=pilot_db, From 04123348fe33f48072040ffa0142b2711ae76b11 Mon Sep 17 00:00:00 2001 From: Robin VAN DE MERGHEL Date: Tue, 5 Aug 2025 13:33:43 +0200 Subject: [PATCH 3/9] fix: Fixed a micro typo --- diracx-logic/src/diracx/logic/pilots/management.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/diracx-logic/src/diracx/logic/pilots/management.py b/diracx-logic/src/diracx/logic/pilots/management.py index 9b9ce9f9f..5698b447e 100644 --- a/diracx-logic/src/diracx/logic/pilots/management.py +++ b/diracx-logic/src/diracx/logic/pilots/management.py @@ -30,7 +30,7 @@ async def register_new_pilots( pilot_db=pilot_db, pilot_stamps=pilot_stamps ) - # If we found pilots from the list, this means some pilots already exists + # If we found pilots from the list, this means some pilots already exist if len(existing_pilots) > 0: found_keys = {pilot["PilotStamp"] for pilot in existing_pilots} From c666b94cc266d88f466b392b3df7444e548b86e6 Mon Sep 17 00:00:00 2001 From: Federico Stagni Date: Tue, 24 Mar 2026 16:06:16 +0100 Subject: [PATCH 4/9] chore: regenerate client --- .../_generated/aio/operations/_operations.py | 17 +++++----- .../_generated/aio/operations/_patch.py | 10 ++++-- .../client/_generated/models/_models.py | 31 +++---------------- .../_generated/operations/_operations.py | 19 ++++++------ .../client/_generated/operations/_patch.py | 10 ++++-- diracx-core/src/diracx/core/models/job.py | 2 +- diracx-core/src/diracx/core/models/search.py | 3 +- diracx-db/src/diracx/db/sql/dummy/db.py | 2 +- diracx-db/src/diracx/db/sql/pilots/db.py | 9 ++---- diracx-db/src/diracx/db/sql/pilots/schema.py | 1 + diracx-db/src/diracx/db/sql/utils/__init__.py | 16 ---------- .../diracx/routers/pilots/access_policies.py | 3 +- .../client/_generated/models/_models.py | 31 +++---------------- 13 files changed, 52 insertions(+), 102 deletions(-) diff --git a/diracx-client/src/diracx/client/_generated/aio/operations/_operations.py b/diracx-client/src/diracx/client/_generated/aio/operations/_operations.py index 0e46aee29..a295bdfc9 100644 --- a/diracx-client/src/diracx/client/_generated/aio/operations/_operations.py +++ b/diracx-client/src/diracx/client/_generated/aio/operations/_operations.py @@ -2451,7 +2451,7 @@ async def add_pilot_stamps(self, body: Union[_models.BodyPilotsAddPilotStamps, I async def delete_pilots( self, *, - pilot_stamps: Optional[List[str]] = None, + pilot_stamps: Optional[list[str]] = None, age_in_days: Optional[int] = None, delete_only_aborted: bool = False, **kwargs: Any @@ -2622,7 +2622,7 @@ async def update_pilot_fields( @distributed_trace_async async def get_pilot_jobs( self, *, pilot_stamp: Optional[str] = None, job_id: Optional[int] = None, **kwargs: Any - ) -> List[int]: + ) -> list[int]: """Get Pilot Jobs. Endpoint only for admins, to get jobs of a pilot. @@ -2646,7 +2646,7 @@ async def get_pilot_jobs( _headers = kwargs.pop("headers", {}) or {} _params = kwargs.pop("params", {}) or {} - cls: ClsType[List[int]] = kwargs.pop("cls", None) + cls: ClsType[list[int]] = kwargs.pop("cls", None) _request = build_pilots_get_pilot_jobs_request( pilot_stamp=pilot_stamp, @@ -2683,7 +2683,7 @@ async def search( per_page: int = 100, content_type: str = "application/json", **kwargs: Any - ) -> List[Dict[str, Any]]: + ) -> list[dict[str, Any]]: """Search. Retrieve information about pilots. @@ -2711,7 +2711,7 @@ async def search( per_page: int = 100, content_type: str = "application/json", **kwargs: Any - ) -> List[Dict[str, Any]]: + ) -> list[dict[str, Any]]: """Search. Retrieve information about pilots. @@ -2738,7 +2738,7 @@ async def search( page: int = 1, per_page: int = 100, **kwargs: Any - ) -> List[Dict[str, Any]]: + ) -> list[dict[str, Any]]: """Search. Retrieve information about pilots. @@ -2765,9 +2765,10 @@ async def search( _params = kwargs.pop("params", {}) or {} content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) - cls: ClsType[List[Dict[str, Any]]] = kwargs.pop("cls", None) + content_type = content_type if body else None + cls: ClsType[list[dict[str, Any]]] = kwargs.pop("cls", None) - content_type = content_type or "application/json" + content_type = content_type or "application/json" if body else None _json = None _content = None if isinstance(body, (IOBase, bytes)): diff --git a/diracx-client/src/diracx/client/_generated/aio/operations/_patch.py b/diracx-client/src/diracx/client/_generated/aio/operations/_patch.py index 0c70ce3e9..fa7f62e3b 100644 --- a/diracx-client/src/diracx/client/_generated/aio/operations/_patch.py +++ b/diracx-client/src/diracx/client/_generated/aio/operations/_patch.py @@ -6,17 +6,23 @@ Follow our quickstart for examples: https://aka.ms/azsdk/python/dpcodegen/python/customize """ + from __future__ import annotations __all__ = [ "AuthOperations", "JobsOperations", - "PilotsOperations", ] # Add all objects you want publicly available to users at this package level from ....patches.auth.aio import AuthOperations from ....patches.jobs.aio import JobsOperations -from ....patches.pilots.aio import PilotsOperations + +try: + from ....patches.pilots.aio import PilotsOperations + + __all__.append("PilotsOperations") +except ImportError: + pass def patch_sdk(): diff --git a/diracx-client/src/diracx/client/_generated/models/_models.py b/diracx-client/src/diracx/client/_generated/models/_models.py index f592def1f..7c7730dc7 100644 --- a/diracx-client/src/diracx/client/_generated/models/_models.py +++ b/diracx-client/src/diracx/client/_generated/models/_models.py @@ -5,7 +5,6 @@ # Changes may cause incorrect behavior and will be lost if the code is regenerated. # -------------------------------------------------------------------------- -from collections.abc import MutableMapping import datetime from typing import Any, Optional, TYPE_CHECKING, Union @@ -13,7 +12,6 @@ if TYPE_CHECKING: from .. import models as _models -JSON = MutableMapping[str, Any] class BodyAuthGetOidcToken(_serialization.Model): @@ -224,12 +222,12 @@ class BodyPilotsAddPilotStamps(_serialization.Model): def __init__( self, *, - pilot_stamps: List[str], + pilot_stamps: list[str], vo: str, grid_type: str = "Dirac", grid_site: str = "Unknown", destination_site: str = "NotAssigned", - pilot_references: Optional[Dict[str, str]] = None, + pilot_references: Optional[dict[str, str]] = None, pilot_status: Optional[Union[str, "_models.PilotStatus"]] = None, **kwargs: Any ) -> None: @@ -277,7 +275,7 @@ class BodyPilotsUpdatePilotFields(_serialization.Model): "pilot_stamps_to_fields_mapping": {"key": "pilot_stamps_to_fields_mapping", "type": "[PilotFieldsMapping]"}, } - def __init__(self, *, pilot_stamps_to_fields_mapping: List["_models.PilotFieldsMapping"], **kwargs: Any) -> None: + def __init__(self, *, pilot_stamps_to_fields_mapping: list["_models.PilotFieldsMapping"], **kwargs: Any) -> None: """ :keyword pilot_stamps_to_fields_mapping: (pilot_stamp, pilot_fields) mapping to change. Required. @@ -1706,10 +1704,6 @@ class ValidationError(_serialization.Model): :vartype msg: str :ivar type: Error Type. Required. :vartype type: str - :ivar input: Input. - :vartype input: any - :ivar ctx: Context. - :vartype ctx: JSON """ _validation = { @@ -1722,20 +1716,9 @@ class ValidationError(_serialization.Model): "loc": {"key": "loc", "type": "[object]"}, "msg": {"key": "msg", "type": "str"}, "type": {"key": "type", "type": "str"}, - "input": {"key": "input", "type": "object"}, - "ctx": {"key": "ctx", "type": "object"}, } - def __init__( - self, - *, - loc: list[Any], - msg: str, - type: str, - input: Optional[Any] = None, - ctx: Optional[JSON] = None, - **kwargs: Any - ) -> None: + def __init__(self, *, loc: list[Any], msg: str, type: str, **kwargs: Any) -> None: """ :keyword loc: Location. Required. :paramtype loc: list[any] @@ -1743,17 +1726,11 @@ def __init__( :paramtype msg: str :keyword type: Error Type. Required. :paramtype type: str - :keyword input: Input. - :paramtype input: any - :keyword ctx: Context. - :paramtype ctx: JSON """ super().__init__(**kwargs) self.loc = loc self.msg = msg self.type = type - self.input = input - self.ctx = ctx class VectorSearchSpec(_serialization.Model): diff --git a/diracx-client/src/diracx/client/_generated/operations/_operations.py b/diracx-client/src/diracx/client/_generated/operations/_operations.py index 954a82b0a..9f211a610 100644 --- a/diracx-client/src/diracx/client/_generated/operations/_operations.py +++ b/diracx-client/src/diracx/client/_generated/operations/_operations.py @@ -584,7 +584,7 @@ def build_pilots_add_pilot_stamps_request(**kwargs: Any) -> HttpRequest: def build_pilots_delete_pilots_request( *, - pilot_stamps: Optional[List[str]] = None, + pilot_stamps: Optional[list[str]] = None, age_in_days: Optional[int] = None, delete_only_aborted: bool = False, **kwargs: Any @@ -3062,7 +3062,7 @@ def add_pilot_stamps(self, body: Union[_models.BodyPilotsAddPilotStamps, IO[byte def delete_pilots( # pylint: disable=inconsistent-return-statements self, *, - pilot_stamps: Optional[List[str]] = None, + pilot_stamps: Optional[list[str]] = None, age_in_days: Optional[int] = None, delete_only_aborted: bool = False, **kwargs: Any @@ -3231,7 +3231,7 @@ def update_pilot_fields( # pylint: disable=inconsistent-return-statements @distributed_trace def get_pilot_jobs( self, *, pilot_stamp: Optional[str] = None, job_id: Optional[int] = None, **kwargs: Any - ) -> List[int]: + ) -> list[int]: """Get Pilot Jobs. Endpoint only for admins, to get jobs of a pilot. @@ -3255,7 +3255,7 @@ def get_pilot_jobs( _headers = kwargs.pop("headers", {}) or {} _params = kwargs.pop("params", {}) or {} - cls: ClsType[List[int]] = kwargs.pop("cls", None) + cls: ClsType[list[int]] = kwargs.pop("cls", None) _request = build_pilots_get_pilot_jobs_request( pilot_stamp=pilot_stamp, @@ -3292,7 +3292,7 @@ def search( per_page: int = 100, content_type: str = "application/json", **kwargs: Any - ) -> List[Dict[str, Any]]: + ) -> list[dict[str, Any]]: """Search. Retrieve information about pilots. @@ -3320,7 +3320,7 @@ def search( per_page: int = 100, content_type: str = "application/json", **kwargs: Any - ) -> List[Dict[str, Any]]: + ) -> list[dict[str, Any]]: """Search. Retrieve information about pilots. @@ -3347,7 +3347,7 @@ def search( page: int = 1, per_page: int = 100, **kwargs: Any - ) -> List[Dict[str, Any]]: + ) -> list[dict[str, Any]]: """Search. Retrieve information about pilots. @@ -3374,9 +3374,10 @@ def search( _params = kwargs.pop("params", {}) or {} content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) - cls: ClsType[List[Dict[str, Any]]] = kwargs.pop("cls", None) + content_type = content_type if body else None + cls: ClsType[list[dict[str, Any]]] = kwargs.pop("cls", None) - content_type = content_type or "application/json" + content_type = content_type or "application/json" if body else None _json = None _content = None if isinstance(body, (IOBase, bytes)): diff --git a/diracx-client/src/diracx/client/_generated/operations/_patch.py b/diracx-client/src/diracx/client/_generated/operations/_patch.py index b14e98b84..043341965 100644 --- a/diracx-client/src/diracx/client/_generated/operations/_patch.py +++ b/diracx-client/src/diracx/client/_generated/operations/_patch.py @@ -6,17 +6,23 @@ Follow our quickstart for examples: https://aka.ms/azsdk/python/dpcodegen/python/customize """ + from __future__ import annotations __all__ = [ "AuthOperations", "JobsOperations", - "PilotsOperations", ] # Add all objects you want publicly available to users at this package level from ...patches.auth.sync import AuthOperations from ...patches.jobs.sync import JobsOperations -from ...patches.pilots.sync import PilotsOperations + +try: + from ...patches.pilots.sync import PilotsOperations + + __all__.append("PilotsOperations") +except ImportError: + pass def patch_sdk(): diff --git a/diracx-core/src/diracx/core/models/job.py b/diracx-core/src/diracx/core/models/job.py index 94e6452b9..ec098c2c6 100644 --- a/diracx-core/src/diracx/core/models/job.py +++ b/diracx-core/src/diracx/core/models/job.py @@ -6,7 +6,7 @@ from __future__ import annotations from enum import StrEnum -from typing import Literal, Optional +from typing import Literal from pydantic import BaseModel, Field, field_validator diff --git a/diracx-core/src/diracx/core/models/search.py b/diracx-core/src/diracx/core/models/search.py index 9d35e86c9..d8baf2cb8 100644 --- a/diracx-core/src/diracx/core/models/search.py +++ b/diracx-core/src/diracx/core/models/search.py @@ -1,5 +1,6 @@ from __future__ import annotations +from datetime import datetime from enum import StrEnum from pydantic import BaseModel @@ -24,7 +25,7 @@ class VectorSearchOperator(StrEnum): class ScalarSearchSpec(TypedDict): parameter: str operator: ScalarSearchOperator - value: str | int + value: str | int | datetime class VectorSearchSpec(TypedDict): diff --git a/diracx-db/src/diracx/db/sql/dummy/db.py b/diracx-db/src/diracx/db/sql/dummy/db.py index 966b6381e..1899d930e 100644 --- a/diracx-db/src/diracx/db/sql/dummy/db.py +++ b/diracx-db/src/diracx/db/sql/dummy/db.py @@ -3,7 +3,7 @@ from sqlalchemy import insert from uuid_utils import UUID -from diracx.core.models import SearchSpec +from diracx.core.models.search import SearchSpec from diracx.db.sql.utils import BaseSQLDB from .schema import Base as DummyDBBase diff --git a/diracx-db/src/diracx/db/sql/pilots/db.py b/diracx-db/src/diracx/db/sql/pilots/db.py index 2cdf6bf39..c5b94f70a 100644 --- a/diracx-db/src/diracx/db/sql/pilots/db.py +++ b/diracx-db/src/diracx/db/sql/pilots/db.py @@ -76,11 +76,7 @@ async def add_jobs_to_pilot(self, job_to_pilot_mapping: list[dict[str, Any]]): """Associate a pilot with jobs. job_to_pilot_mapping format: - ```py - job_to_pilot_mapping = [ - {"PilotID": pilot_id, "JobID": job_id, "StartTime": now} - ] - ``` + job_to_pilot_mapping = [{"PilotID": pilot_id, "JobID": job_id, "StartTime": now}] Raises: - PilotNotFoundError if a pilot_id is not associated with a pilot. @@ -144,7 +140,7 @@ async def update_pilot_fields( """Bulk update pilots with a mapping. pilot_stamps_to_fields_mapping format: - ```py + [ { "PilotStamp": pilot_stamp, @@ -157,7 +153,6 @@ async def update_pilot_fields( ... } ] - ``` The mapping helps to update multiple fields at a time. diff --git a/diracx-db/src/diracx/db/sql/pilots/schema.py b/diracx-db/src/diracx/db/sql/pilots/schema.py index 4e0fbb9b2..701739ff9 100644 --- a/diracx-db/src/diracx/db/sql/pilots/schema.py +++ b/diracx-db/src/diracx/db/sql/pilots/schema.py @@ -11,6 +11,7 @@ ) from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column +from diracx.core.models.pilot import PilotStatus from diracx.db.sql.utils import ( EnumBackedBool, str32, diff --git a/diracx-db/src/diracx/db/sql/utils/__init__.py b/diracx-db/src/diracx/db/sql/utils/__init__.py index ab29c76e8..ad262186b 100644 --- a/diracx-db/src/diracx/db/sql/utils/__init__.py +++ b/diracx-db/src/diracx/db/sql/utils/__init__.py @@ -43,19 +43,3 @@ str512, str1024, ) - -__all__ = ( - "_get_columns", - "apply_search_filters", - "apply_sort_constraints", - "BaseSQLDB", - "Column", - "DateNowColumn", - "EnumBackedBool", - "EnumColumn", - "hash", - "NullColumn", - "substract_date", - "SQLDBUnavailableError", - "utcno", -) diff --git a/diracx-routers/src/diracx/routers/pilots/access_policies.py b/diracx-routers/src/diracx/routers/pilots/access_policies.py index 011633d9b..ecacd2710 100644 --- a/diracx-routers/src/diracx/routers/pilots/access_policies.py +++ b/diracx-routers/src/diracx/routers/pilots/access_policies.py @@ -23,7 +23,8 @@ class ActionType(StrEnum): class PilotManagementAccessPolicy(BaseAccessPolicy): - """Rules: + """Pilot management access policy. + * Every user can access data about his VO * An administrator can modify a pilot. """ diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/_models.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/_models.py index 946623cbb..34417d65f 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/_models.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/_models.py @@ -5,7 +5,6 @@ # Changes may cause incorrect behavior and will be lost if the code is regenerated. # -------------------------------------------------------------------------- -from collections.abc import MutableMapping import datetime from typing import Any, Optional, TYPE_CHECKING, Union @@ -13,7 +12,6 @@ if TYPE_CHECKING: from .. import models as _models -JSON = MutableMapping[str, Any] class BodyAuthGetOidcToken(_serialization.Model): @@ -224,12 +222,12 @@ class BodyPilotsAddPilotStamps(_serialization.Model): def __init__( self, *, - pilot_stamps: List[str], + pilot_stamps: list[str], vo: str, grid_type: str = "Dirac", grid_site: str = "Unknown", destination_site: str = "NotAssigned", - pilot_references: Optional[Dict[str, str]] = None, + pilot_references: Optional[dict[str, str]] = None, pilot_status: Optional[Union[str, "_models.PilotStatus"]] = None, **kwargs: Any ) -> None: @@ -277,7 +275,7 @@ class BodyPilotsUpdatePilotFields(_serialization.Model): "pilot_stamps_to_fields_mapping": {"key": "pilot_stamps_to_fields_mapping", "type": "[PilotFieldsMapping]"}, } - def __init__(self, *, pilot_stamps_to_fields_mapping: List["_models.PilotFieldsMapping"], **kwargs: Any) -> None: + def __init__(self, *, pilot_stamps_to_fields_mapping: list["_models.PilotFieldsMapping"], **kwargs: Any) -> None: """ :keyword pilot_stamps_to_fields_mapping: (pilot_stamp, pilot_fields) mapping to change. Required. @@ -1727,10 +1725,6 @@ class ValidationError(_serialization.Model): :vartype msg: str :ivar type: Error Type. Required. :vartype type: str - :ivar input: Input. - :vartype input: any - :ivar ctx: Context. - :vartype ctx: JSON """ _validation = { @@ -1743,20 +1737,9 @@ class ValidationError(_serialization.Model): "loc": {"key": "loc", "type": "[object]"}, "msg": {"key": "msg", "type": "str"}, "type": {"key": "type", "type": "str"}, - "input": {"key": "input", "type": "object"}, - "ctx": {"key": "ctx", "type": "object"}, } - def __init__( - self, - *, - loc: list[Any], - msg: str, - type: str, - input: Optional[Any] = None, - ctx: Optional[JSON] = None, - **kwargs: Any - ) -> None: + def __init__(self, *, loc: list[Any], msg: str, type: str, **kwargs: Any) -> None: """ :keyword loc: Location. Required. :paramtype loc: list[any] @@ -1764,17 +1747,11 @@ def __init__( :paramtype msg: str :keyword type: Error Type. Required. :paramtype type: str - :keyword input: Input. - :paramtype input: any - :keyword ctx: Context. - :paramtype ctx: JSON """ super().__init__(**kwargs) self.loc = loc self.msg = msg self.type = type - self.input = input - self.ctx = ctx class VectorSearchSpec(_serialization.Model): From 90a3e7f9e3be7bec39c8c192be75c16c227097a6 Mon Sep 17 00:00:00 2001 From: Federico Stagni Date: Wed, 25 Mar 2026 11:05:46 +0100 Subject: [PATCH 5/9] chore: regenerate client --- .../client/_generated/models/_models.py | 25 ++++++++++++++++++- .../client/_generated/models/_models.py | 25 ++++++++++++++++++- 2 files changed, 48 insertions(+), 2 deletions(-) diff --git a/diracx-client/src/diracx/client/_generated/models/_models.py b/diracx-client/src/diracx/client/_generated/models/_models.py index 7c7730dc7..a6306e67e 100644 --- a/diracx-client/src/diracx/client/_generated/models/_models.py +++ b/diracx-client/src/diracx/client/_generated/models/_models.py @@ -5,6 +5,7 @@ # Changes may cause incorrect behavior and will be lost if the code is regenerated. # -------------------------------------------------------------------------- +from collections.abc import MutableMapping import datetime from typing import Any, Optional, TYPE_CHECKING, Union @@ -12,6 +13,7 @@ if TYPE_CHECKING: from .. import models as _models +JSON = MutableMapping[str, Any] class BodyAuthGetOidcToken(_serialization.Model): @@ -1704,6 +1706,10 @@ class ValidationError(_serialization.Model): :vartype msg: str :ivar type: Error Type. Required. :vartype type: str + :ivar input: Input. + :vartype input: any + :ivar ctx: Context. + :vartype ctx: JSON """ _validation = { @@ -1716,9 +1722,20 @@ class ValidationError(_serialization.Model): "loc": {"key": "loc", "type": "[object]"}, "msg": {"key": "msg", "type": "str"}, "type": {"key": "type", "type": "str"}, + "input": {"key": "input", "type": "object"}, + "ctx": {"key": "ctx", "type": "object"}, } - def __init__(self, *, loc: list[Any], msg: str, type: str, **kwargs: Any) -> None: + def __init__( + self, + *, + loc: list[Any], + msg: str, + type: str, + input: Optional[Any] = None, + ctx: Optional[JSON] = None, + **kwargs: Any + ) -> None: """ :keyword loc: Location. Required. :paramtype loc: list[any] @@ -1726,11 +1743,17 @@ def __init__(self, *, loc: list[Any], msg: str, type: str, **kwargs: Any) -> Non :paramtype msg: str :keyword type: Error Type. Required. :paramtype type: str + :keyword input: Input. + :paramtype input: any + :keyword ctx: Context. + :paramtype ctx: JSON """ super().__init__(**kwargs) self.loc = loc self.msg = msg self.type = type + self.input = input + self.ctx = ctx class VectorSearchSpec(_serialization.Model): diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/_models.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/_models.py index 34417d65f..2e88298bb 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/_models.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/_models.py @@ -5,6 +5,7 @@ # Changes may cause incorrect behavior and will be lost if the code is regenerated. # -------------------------------------------------------------------------- +from collections.abc import MutableMapping import datetime from typing import Any, Optional, TYPE_CHECKING, Union @@ -12,6 +13,7 @@ if TYPE_CHECKING: from .. import models as _models +JSON = MutableMapping[str, Any] class BodyAuthGetOidcToken(_serialization.Model): @@ -1725,6 +1727,10 @@ class ValidationError(_serialization.Model): :vartype msg: str :ivar type: Error Type. Required. :vartype type: str + :ivar input: Input. + :vartype input: any + :ivar ctx: Context. + :vartype ctx: JSON """ _validation = { @@ -1737,9 +1743,20 @@ class ValidationError(_serialization.Model): "loc": {"key": "loc", "type": "[object]"}, "msg": {"key": "msg", "type": "str"}, "type": {"key": "type", "type": "str"}, + "input": {"key": "input", "type": "object"}, + "ctx": {"key": "ctx", "type": "object"}, } - def __init__(self, *, loc: list[Any], msg: str, type: str, **kwargs: Any) -> None: + def __init__( + self, + *, + loc: list[Any], + msg: str, + type: str, + input: Optional[Any] = None, + ctx: Optional[JSON] = None, + **kwargs: Any + ) -> None: """ :keyword loc: Location. Required. :paramtype loc: list[any] @@ -1747,11 +1764,17 @@ def __init__(self, *, loc: list[Any], msg: str, type: str, **kwargs: Any) -> Non :paramtype msg: str :keyword type: Error Type. Required. :paramtype type: str + :keyword input: Input. + :paramtype input: any + :keyword ctx: Context. + :paramtype ctx: JSON """ super().__init__(**kwargs) self.loc = loc self.msg = msg self.type = type + self.input = input + self.ctx = ctx class VectorSearchSpec(_serialization.Model): From 5167d3732563705e71522599dd82b0878d784f1d Mon Sep 17 00:00:00 2001 From: Federico Stagni Date: Wed, 25 Mar 2026 14:35:03 +0100 Subject: [PATCH 6/9] fix: added __all__ --- diracx-db/src/diracx/db/sql/pilots/db.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/diracx-db/src/diracx/db/sql/pilots/db.py b/diracx-db/src/diracx/db/sql/pilots/db.py index c5b94f70a..e88ae9bb7 100644 --- a/diracx-db/src/diracx/db/sql/pilots/db.py +++ b/diracx-db/src/diracx/db/sql/pilots/db.py @@ -1,5 +1,7 @@ from __future__ import annotations +__all__ = ["PilotAgentsDB"] + from datetime import datetime, timezone from typing import Any From c9ee8a6aa467e63a3a0bfe01cbf0b5ed2d34fb0a Mon Sep 17 00:00:00 2001 From: Federico Stagni Date: Wed, 25 Mar 2026 14:35:36 +0100 Subject: [PATCH 7/9] fix: added query validation --- diracx-routers/src/diracx/routers/pilots/query.py | 7 ++++--- diracx-routers/tests/pilots/test_query.py | 4 ++-- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/diracx-routers/src/diracx/routers/pilots/query.py b/diracx-routers/src/diracx/routers/pilots/query.py index 29001e0c7..8b956ec68 100644 --- a/diracx-routers/src/diracx/routers/pilots/query.py +++ b/diracx-routers/src/diracx/routers/pilots/query.py @@ -3,10 +3,11 @@ from http import HTTPStatus from typing import Annotated, Any -from fastapi import Body, Depends, Response +from fastapi import Body, Depends, Query, Response from diracx.core.models.search import SearchParams, SummaryParams from diracx.db.sql import PilotAgentsDB +from diracx.logic.pilots.query import MAX_PER_PAGE from diracx.logic.pilots.query import search as search_bl from diracx.logic.pilots.query import summary as summary_bl @@ -111,8 +112,8 @@ async def search( check_permissions: CheckPilotManagementPolicyCallable, response: Response, user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)], - page: int = 1, - per_page: int = 100, + page: Annotated[int, Query(ge=1)] = 1, + per_page: Annotated[int, Query(ge=1, le=MAX_PER_PAGE)] = 100, body: Annotated[ SearchParams | None, Body(openapi_examples=EXAMPLE_SEARCHES) # type: ignore ] = None, diff --git a/diracx-routers/tests/pilots/test_query.py b/diracx-routers/tests/pilots/test_query.py index eeea8423e..6f2258a82 100644 --- a/diracx-routers/tests/pilots/test_query.py +++ b/diracx-routers/tests/pilots/test_query.py @@ -133,8 +133,8 @@ async def _search( r = populated_pilot_client.post("/api/pilots/search", json=body, params=params) - if r.status_code == 400: - # If we have a status_code 400, that means that the query failed + if r.status_code in (400, 422): + # If we have a status_code 400/422, that means that the query failed raise InvalidQueryError() return r.json(), r.headers From 0dc45cf12009f8f725c2c4b10567b0ae646b3cd8 Mon Sep 17 00:00:00 2001 From: Federico Stagni Date: Wed, 25 Mar 2026 14:42:38 +0100 Subject: [PATCH 8/9] chore: regenerate client --- .../diracx/client/_generated/operations/_operations.py | 4 ++-- diracx-db/src/diracx/db/sql/pilots/schema.py | 1 - diracx-db/tests/pilots/test_query.py | 9 ++++++++- diracx-db/tests/pilots/utils.py | 7 ++++++- diracx-logic/src/diracx/logic/pilots/query.py | 10 +++++++++- diracx-routers/tests/pilots/test_query.py | 9 ++++++++- 6 files changed, 33 insertions(+), 7 deletions(-) diff --git a/diracx-client/src/diracx/client/_generated/operations/_operations.py b/diracx-client/src/diracx/client/_generated/operations/_operations.py index 9f211a610..f36059331 100644 --- a/diracx-client/src/diracx/client/_generated/operations/_operations.py +++ b/diracx-client/src/diracx/client/_generated/operations/_operations.py @@ -654,9 +654,9 @@ def build_pilots_search_request(*, page: int = 1, per_page: int = 100, **kwargs: # Construct parameters if page is not None: - _params["page"] = _SERIALIZER.query("page", page, "int") + _params["page"] = _SERIALIZER.query("page", page, "int", minimum=1) if per_page is not None: - _params["per_page"] = _SERIALIZER.query("per_page", per_page, "int") + _params["per_page"] = _SERIALIZER.query("per_page", per_page, "int", maximum=10000, minimum=1) # Construct headers if content_type is not None: diff --git a/diracx-db/src/diracx/db/sql/pilots/schema.py b/diracx-db/src/diracx/db/sql/pilots/schema.py index 701739ff9..0cd95ff9a 100644 --- a/diracx-db/src/diracx/db/sql/pilots/schema.py +++ b/diracx-db/src/diracx/db/sql/pilots/schema.py @@ -19,7 +19,6 @@ str255, ) from diracx.db.sql.utils.types import SmarterDateTime -from diracx.core.models.pilot import PilotStatus class PilotAgentsDBBase(DeclarativeBase): diff --git a/diracx-db/tests/pilots/test_query.py b/diracx-db/tests/pilots/test_query.py index d1e5b1da3..592329d8c 100644 --- a/diracx-db/tests/pilots/test_query.py +++ b/diracx-db/tests/pilots/test_query.py @@ -4,7 +4,14 @@ from diracx.core.exceptions import InvalidQueryError from diracx.core.models.pilot import PilotFieldsMapping, PilotStatus -from diracx.core.models.search import ScalarSearchOperator, ScalarSearchSpec, SortDirection, SortSpec, VectorSearchOperator, VectorSearchSpec +from diracx.core.models.search import ( + ScalarSearchOperator, + ScalarSearchSpec, + SortDirection, + SortSpec, + VectorSearchOperator, + VectorSearchSpec, +) from diracx.db.sql.pilots.db import PilotAgentsDB MAIN_VO = "lhcb" diff --git a/diracx-db/tests/pilots/utils.py b/diracx-db/tests/pilots/utils.py index c7f2e2908..d803a3597 100644 --- a/diracx-db/tests/pilots/utils.py +++ b/diracx-db/tests/pilots/utils.py @@ -6,7 +6,12 @@ import pytest from sqlalchemy import update -from diracx.core.models.search import ScalarSearchOperator, ScalarSearchSpec, VectorSearchOperator, VectorSearchSpec +from diracx.core.models.search import ( + ScalarSearchOperator, + ScalarSearchSpec, + VectorSearchOperator, + VectorSearchSpec, +) from diracx.db.sql.pilots.db import PilotAgentsDB from diracx.db.sql.pilots.schema import PilotAgents diff --git a/diracx-logic/src/diracx/logic/pilots/query.py b/diracx-logic/src/diracx/logic/pilots/query.py index 7487e0bfc..edad8f7ec 100644 --- a/diracx-logic/src/diracx/logic/pilots/query.py +++ b/diracx-logic/src/diracx/logic/pilots/query.py @@ -5,7 +5,15 @@ from diracx.core.exceptions import PilotNotFoundError from diracx.core.models.pilot import PilotStatus -from diracx.core.models.search import ScalarSearchOperator, ScalarSearchSpec, SearchParams, SearchSpec, SummaryParams, VectorSearchOperator, VectorSearchSpec +from diracx.core.models.search import ( + ScalarSearchOperator, + ScalarSearchSpec, + SearchParams, + SearchSpec, + SummaryParams, + VectorSearchOperator, + VectorSearchSpec, +) from diracx.db.sql import PilotAgentsDB MAX_PER_PAGE = 10000 diff --git a/diracx-routers/tests/pilots/test_query.py b/diracx-routers/tests/pilots/test_query.py index 6f2258a82..c6c1c7e35 100644 --- a/diracx-routers/tests/pilots/test_query.py +++ b/diracx-routers/tests/pilots/test_query.py @@ -7,7 +7,14 @@ from diracx.core.exceptions import InvalidQueryError from diracx.core.models.pilot import PilotFieldsMapping, PilotStatus -from diracx.core.models.search import ScalarSearchOperator, ScalarSearchSpec, SortDirection, SortSpec, VectorSearchOperator, VectorSearchSpec +from diracx.core.models.search import ( + ScalarSearchOperator, + ScalarSearchSpec, + SortDirection, + SortSpec, + VectorSearchOperator, + VectorSearchSpec, +) pytestmark = pytest.mark.enabled_dependencies( [ From 6e47a7ae5c9e47da59b033feb22cc605870c06c8 Mon Sep 17 00:00:00 2001 From: aldbr Date: Tue, 14 Apr 2026 11:11:42 +0200 Subject: [PATCH 9/9] fix: broad rework of the pilots stack --- .../_generated/aio/operations/_operations.py | 227 +++---- .../_generated/aio/operations/_patch.py | 9 +- .../client/_generated/models/__init__.py | 12 +- .../client/_generated/models/_models.py | 86 +-- .../_generated/operations/_operations.py | 262 +++---- .../client/_generated/operations/_patch.py | 9 +- .../src/diracx/client/patches/pilots/aio.py | 46 +- .../diracx/client/patches/pilots/common.py | 96 ++- .../src/diracx/client/patches/pilots/sync.py | 32 +- diracx-core/src/diracx/core/exceptions.py | 8 +- diracx-core/src/diracx/core/models/pilot.py | 42 +- diracx-core/src/diracx/core/models/search.py | 3 +- diracx-db/src/diracx/db/sql/pilots/db.py | 242 ++++--- diracx-db/src/diracx/db/sql/pilots/schema.py | 5 +- .../tests/pilots/test_pilot_management.py | 227 +++---- diracx-db/tests/pilots/test_query.py | 323 ++------- diracx-db/tests/pilots/utils.py | 128 +--- diracx-logic/src/diracx/logic/jobs/query.py | 87 ++- .../src/diracx/logic/pilots/management.py | 166 +++-- diracx-logic/src/diracx/logic/pilots/query.py | 277 ++++---- .../src/diracx/routers/jobs/query.py | 24 +- .../diracx/routers/pilots/access_policies.py | 95 +-- .../src/diracx/routers/pilots/management.py | 224 ++---- .../src/diracx/routers/pilots/query.py | 85 ++- .../tests/jobs/test_heartbeat_commands.py | 2 + diracx-routers/tests/jobs/test_query.py | 124 ++++ diracx-routers/tests/jobs/test_status.py | 2 + diracx-routers/tests/pilots/__init__.py | 0 .../tests/pilots/test_access_policy.py | 127 ++++ .../tests/pilots/test_pilot_creation.py | 298 +++----- diracx-routers/tests/pilots/test_query.py | 480 ++++--------- docs/admin/explanations/pilots.md | 11 + docs/dev/explanations/pilots.md | 62 +- .../src/gubbins/client/_generated/_client.py | 6 +- .../gubbins/client/_generated/aio/_client.py | 6 +- .../_generated/aio/operations/__init__.py | 2 + .../_generated/aio/operations/_operations.py | 566 +++++++++++++++ .../client/_generated/models/__init__.py | 12 +- .../client/_generated/models/_models.py | 86 +-- .../client/_generated/operations/__init__.py | 2 + .../_generated/operations/_operations.py | 643 ++++++++++++++++++ .../tests/test_gubbins_job_manager.py | 2 + pixi.lock | 10 +- 43 files changed, 2968 insertions(+), 2188 deletions(-) create mode 100644 diracx-routers/tests/pilots/__init__.py create mode 100644 diracx-routers/tests/pilots/test_access_policy.py create mode 100644 docs/admin/explanations/pilots.md diff --git a/diracx-client/src/diracx/client/_generated/aio/operations/_operations.py b/diracx-client/src/diracx/client/_generated/aio/operations/_operations.py index a295bdfc9..032cac02e 100644 --- a/diracx-client/src/diracx/client/_generated/aio/operations/_operations.py +++ b/diracx-client/src/diracx/client/_generated/aio/operations/_operations.py @@ -51,12 +51,11 @@ build_jobs_summary_request, build_jobs_unassign_bulk_jobs_sandboxes_request, build_jobs_unassign_job_sandboxes_request, - build_pilots_add_pilot_stamps_request, build_pilots_delete_pilots_request, - build_pilots_get_pilot_jobs_request, + build_pilots_register_pilots_request, build_pilots_search_request, build_pilots_summary_request, - build_pilots_update_pilot_fields_request, + build_pilots_update_pilot_metadata_request, build_well_known_get_installation_metadata_request, build_well_known_get_jwks_request, build_well_known_get_openid_configuration_request, @@ -1959,6 +1958,12 @@ async def search( By default, the search will return all jobs the user has access to, and all the fields of the job will be returned. + A ``PilotStamp`` pseudo-parameter is also accepted in the ``search`` + filter list (operators ``eq`` / ``in`` only): it is transparently + resolved through ``JobToPilotMapping`` into a ``JobID`` filter, + allowing callers to ask "jobs run by this pilot" through the same + endpoint. + :param body: Default value is None. :type body: ~_generated.models.SearchParams :keyword page: Default value is 1. @@ -1999,6 +2004,12 @@ async def search( By default, the search will return all jobs the user has access to, and all the fields of the job will be returned. + A ``PilotStamp`` pseudo-parameter is also accepted in the ``search`` + filter list (operators ``eq`` / ``in`` only): it is transparently + resolved through ``JobToPilotMapping`` into a ``JobID`` filter, + allowing callers to ask "jobs run by this pilot" through the same + endpoint. + :param body: Default value is None. :type body: IO[bytes] :keyword page: Default value is 1. @@ -2038,6 +2049,12 @@ async def search( By default, the search will return all jobs the user has access to, and all the fields of the job will be returned. + A ``PilotStamp`` pseudo-parameter is also accepted in the ``search`` + filter list (operators ``eq`` / ``in`` only): it is transparently + resolved through ``JobToPilotMapping`` into a ``JobID`` filter, + allowing callers to ask "jobs run by this pilot" through the same + endpoint. + :param body: Is either a SearchParams type or a IO[bytes] type. Default value is None. :type body: ~_generated.models.SearchParams or IO[bytes] :keyword page: Default value is 1. @@ -2347,17 +2364,17 @@ def __init__(self, *args, **kwargs) -> None: self._deserialize: Deserializer = input_args.pop(0) if input_args else kwargs.pop("deserializer") @overload - async def add_pilot_stamps( - self, body: _models.BodyPilotsAddPilotStamps, *, content_type: str = "application/json", **kwargs: Any + async def register_pilots( + self, body: _models.BodyPilotsRegisterPilots, *, content_type: str = "application/json", **kwargs: Any ) -> Any: - """Add Pilot Stamps. + """Register Pilots. - Endpoint where a you can create pilots with their references. + Register a batch of pilots with their references. - If a pilot stamp already exists, it will block the insertion. + If any stamp already exists, the whole batch is rejected with a 409. :param body: Required. - :type body: ~_generated.models.BodyPilotsAddPilotStamps + :type body: ~_generated.models.BodyPilotsRegisterPilots :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. Default value is "application/json". :paramtype content_type: str @@ -2367,12 +2384,12 @@ async def add_pilot_stamps( """ @overload - async def add_pilot_stamps(self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any) -> Any: - """Add Pilot Stamps. + async def register_pilots(self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any) -> Any: + """Register Pilots. - Endpoint where a you can create pilots with their references. + Register a batch of pilots with their references. - If a pilot stamp already exists, it will block the insertion. + If any stamp already exists, the whole batch is rejected with a 409. :param body: Required. :type body: IO[bytes] @@ -2385,15 +2402,15 @@ async def add_pilot_stamps(self, body: IO[bytes], *, content_type: str = "applic """ @distributed_trace_async - async def add_pilot_stamps(self, body: Union[_models.BodyPilotsAddPilotStamps, IO[bytes]], **kwargs: Any) -> Any: - """Add Pilot Stamps. + async def register_pilots(self, body: Union[_models.BodyPilotsRegisterPilots, IO[bytes]], **kwargs: Any) -> Any: + """Register Pilots. - Endpoint where a you can create pilots with their references. + Register a batch of pilots with their references. - If a pilot stamp already exists, it will block the insertion. + If any stamp already exists, the whole batch is rejected with a 409. - :param body: Is either a BodyPilotsAddPilotStamps type or a IO[bytes] type. Required. - :type body: ~_generated.models.BodyPilotsAddPilotStamps or IO[bytes] + :param body: Is either a BodyPilotsRegisterPilots type or a IO[bytes] type. Required. + :type body: ~_generated.models.BodyPilotsRegisterPilots or IO[bytes] :return: any :rtype: any :raises ~azure.core.exceptions.HttpResponseError: @@ -2418,9 +2435,9 @@ async def add_pilot_stamps(self, body: Union[_models.BodyPilotsAddPilotStamps, I if isinstance(body, (IOBase, bytes)): _content = body else: - _json = self._serialize.body(body, "BodyPilotsAddPilotStamps") + _json = self._serialize.body(body, "BodyPilotsRegisterPilots") - _request = build_pilots_add_pilot_stamps_request( + _request = build_pilots_register_pilots_request( content_type=content_type, json=_json, content=_content, @@ -2448,36 +2465,19 @@ async def add_pilot_stamps(self, body: Union[_models.BodyPilotsAddPilotStamps, I return deserialized # type: ignore @distributed_trace_async - async def delete_pilots( - self, - *, - pilot_stamps: Optional[list[str]] = None, - age_in_days: Optional[int] = None, - delete_only_aborted: bool = False, - **kwargs: Any - ) -> None: + async def delete_pilots(self, *, pilot_stamps: list[str], **kwargs: Any) -> None: """Delete Pilots. - Endpoint to delete a pilot. + Delete pilots by stamp. - Two features: + Deletes the pilot rows as well as their logs and job associations. + Age-based retention cleanup is deliberately *not* exposed here: it is + handled by the maintenance task worker. See + ``diracx.logic.pilots.management.delete_pilots``. - #. Or you provide pilot_stamps, so you can delete pilots by their stamp - #. Or you provide age_in_days, so you can delete pilots that lived more than age_in_days days. - - Note: If you delete a pilot, its logs and its associations with jobs WILL be deleted. - - :keyword pilot_stamps: Stamps of the pilots we want to delete. Default value is None. + :keyword pilot_stamps: Stamps of the pilots to delete. Required. :paramtype pilot_stamps: list[str] - :keyword age_in_days: The number of days that define the maximum age of pilots to be - deleted.Pilots older than this age will be considered for deletion. Default value is None. - :paramtype age_in_days: int - :keyword delete_only_aborted: Flag indicating whether to only delete pilots whose status is - 'Aborted'.If set to True, only pilots with the 'Aborted' status will be deleted.It is set by - default as True to avoid any mistake.This flag is only used for deletion by time. Default value - is False. - :paramtype delete_only_aborted: bool :return: None :rtype: None :raises ~azure.core.exceptions.HttpResponseError: @@ -2497,8 +2497,6 @@ async def delete_pilots( _request = build_pilots_delete_pilots_request( pilot_stamps=pilot_stamps, - age_in_days=age_in_days, - delete_only_aborted=delete_only_aborted, headers=_headers, params=_params, ) @@ -2519,17 +2517,18 @@ async def delete_pilots( return cls(pipeline_response, None, {}) # type: ignore @overload - async def update_pilot_fields( - self, body: _models.BodyPilotsUpdatePilotFields, *, content_type: str = "application/json", **kwargs: Any + async def update_pilot_metadata( + self, body: _models.BodyPilotsUpdatePilotMetadata, *, content_type: str = "application/json", **kwargs: Any ) -> None: - """Update Pilot Fields. + """Update Pilot Metadata. - Modify a field of a pilot. + Update pilot metadata (status, benchmark, etc.). - Note: Only the fields in PilotFieldsMapping are mutable, except for the PilotStamp. + Only fields defined in ``PilotMetadata`` are mutable. ``PilotStamp`` + identifies the row and cannot be changed. :param body: Required. - :type body: ~_generated.models.BodyPilotsUpdatePilotFields + :type body: ~_generated.models.BodyPilotsUpdatePilotMetadata :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. Default value is "application/json". :paramtype content_type: str @@ -2539,14 +2538,15 @@ async def update_pilot_fields( """ @overload - async def update_pilot_fields( + async def update_pilot_metadata( self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any ) -> None: - """Update Pilot Fields. + """Update Pilot Metadata. - Modify a field of a pilot. + Update pilot metadata (status, benchmark, etc.). - Note: Only the fields in PilotFieldsMapping are mutable, except for the PilotStamp. + Only fields defined in ``PilotMetadata`` are mutable. ``PilotStamp`` + identifies the row and cannot be changed. :param body: Required. :type body: IO[bytes] @@ -2559,17 +2559,18 @@ async def update_pilot_fields( """ @distributed_trace_async - async def update_pilot_fields( - self, body: Union[_models.BodyPilotsUpdatePilotFields, IO[bytes]], **kwargs: Any + async def update_pilot_metadata( + self, body: Union[_models.BodyPilotsUpdatePilotMetadata, IO[bytes]], **kwargs: Any ) -> None: - """Update Pilot Fields. + """Update Pilot Metadata. - Modify a field of a pilot. + Update pilot metadata (status, benchmark, etc.). - Note: Only the fields in PilotFieldsMapping are mutable, except for the PilotStamp. + Only fields defined in ``PilotMetadata`` are mutable. ``PilotStamp`` + identifies the row and cannot be changed. - :param body: Is either a BodyPilotsUpdatePilotFields type or a IO[bytes] type. Required. - :type body: ~_generated.models.BodyPilotsUpdatePilotFields or IO[bytes] + :param body: Is either a BodyPilotsUpdatePilotMetadata type or a IO[bytes] type. Required. + :type body: ~_generated.models.BodyPilotsUpdatePilotMetadata or IO[bytes] :return: None :rtype: None :raises ~azure.core.exceptions.HttpResponseError: @@ -2594,9 +2595,9 @@ async def update_pilot_fields( if isinstance(body, (IOBase, bytes)): _content = body else: - _json = self._serialize.body(body, "BodyPilotsUpdatePilotFields") + _json = self._serialize.body(body, "BodyPilotsUpdatePilotMetadata") - _request = build_pilots_update_pilot_fields_request( + _request = build_pilots_update_pilot_metadata_request( content_type=content_type, json=_json, content=_content, @@ -2619,61 +2620,6 @@ async def update_pilot_fields( if cls: return cls(pipeline_response, None, {}) # type: ignore - @distributed_trace_async - async def get_pilot_jobs( - self, *, pilot_stamp: Optional[str] = None, job_id: Optional[int] = None, **kwargs: Any - ) -> list[int]: - """Get Pilot Jobs. - - Endpoint only for admins, to get jobs of a pilot. - - :keyword pilot_stamp: The stamp of the pilot. Default value is None. - :paramtype pilot_stamp: str - :keyword job_id: The ID of the job. Default value is None. - :paramtype job_id: int - :return: list of int - :rtype: list[int] - :raises ~azure.core.exceptions.HttpResponseError: - """ - error_map: MutableMapping = { - 401: ClientAuthenticationError, - 404: ResourceNotFoundError, - 409: ResourceExistsError, - 304: ResourceNotModifiedError, - } - error_map.update(kwargs.pop("error_map", {}) or {}) - - _headers = kwargs.pop("headers", {}) or {} - _params = kwargs.pop("params", {}) or {} - - cls: ClsType[list[int]] = kwargs.pop("cls", None) - - _request = build_pilots_get_pilot_jobs_request( - pilot_stamp=pilot_stamp, - job_id=job_id, - headers=_headers, - params=_params, - ) - _request.url = self._client.format_url(_request.url) - - _stream = False - pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access - _request, stream=_stream, **kwargs - ) - - response = pipeline_response.http_response - - if response.status_code not in [200]: - map_error(status_code=response.status_code, response=response, error_map=error_map) - raise HttpResponseError(response=response) - - deserialized = self._deserialize("[int]", pipeline_response.http_response) - - if cls: - return cls(pipeline_response, deserialized, {}) # type: ignore - - return deserialized # type: ignore - @overload async def search( self, @@ -2688,6 +2634,14 @@ async def search( Retrieve information about pilots. + Normal users see only their own VO's pilots. Service administrators see + pilots from all VOs. + + A ``JobID`` pseudo-parameter is also accepted in the ``search`` filter + list (operators ``eq`` / ``in`` only): it is transparently resolved + through ``JobToPilotMapping`` into a ``PilotID`` filter, allowing + callers to ask "pilots that ran this job" through the same endpoint. + :param body: Default value is None. :type body: ~_generated.models.SearchParams :keyword page: Default value is 1. @@ -2716,6 +2670,14 @@ async def search( Retrieve information about pilots. + Normal users see only their own VO's pilots. Service administrators see + pilots from all VOs. + + A ``JobID`` pseudo-parameter is also accepted in the ``search`` filter + list (operators ``eq`` / ``in`` only): it is transparently resolved + through ``JobToPilotMapping`` into a ``PilotID`` filter, allowing + callers to ask "pilots that ran this job" through the same endpoint. + :param body: Default value is None. :type body: IO[bytes] :keyword page: Default value is 1. @@ -2743,6 +2705,14 @@ async def search( Retrieve information about pilots. + Normal users see only their own VO's pilots. Service administrators see + pilots from all VOs. + + A ``JobID`` pseudo-parameter is also accepted in the ``search`` filter + list (operators ``eq`` / ``in`` only): it is transparently resolved + through ``JobToPilotMapping`` into a ``PilotID`` filter, allowing + callers to ask "pilots that ran this job" through the same endpoint. + :param body: Is either a SearchParams type or a IO[bytes] type. Default value is None. :type body: ~_generated.models.SearchParams or IO[bytes] :keyword page: Default value is 1. @@ -2818,7 +2788,10 @@ async def summary( ) -> Any: """Summary. - Show information suitable for plotting. + Aggregate pilot counts suitable for plotting. + + Normal users see only their own VO's pilots. Service administrators see + pilots from all VOs. :param body: Required. :type body: ~_generated.models.SummaryParams @@ -2834,7 +2807,10 @@ async def summary( async def summary(self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any) -> Any: """Summary. - Show information suitable for plotting. + Aggregate pilot counts suitable for plotting. + + Normal users see only their own VO's pilots. Service administrators see + pilots from all VOs. :param body: Required. :type body: IO[bytes] @@ -2850,7 +2826,10 @@ async def summary(self, body: IO[bytes], *, content_type: str = "application/jso async def summary(self, body: Union[_models.SummaryParams, IO[bytes]], **kwargs: Any) -> Any: """Summary. - Show information suitable for plotting. + Aggregate pilot counts suitable for plotting. + + Normal users see only their own VO's pilots. Service administrators see + pilots from all VOs. :param body: Is either a SummaryParams type or a IO[bytes] type. Required. :type body: ~_generated.models.SummaryParams or IO[bytes] diff --git a/diracx-client/src/diracx/client/_generated/aio/operations/_patch.py b/diracx-client/src/diracx/client/_generated/aio/operations/_patch.py index fa7f62e3b..c950258bc 100644 --- a/diracx-client/src/diracx/client/_generated/aio/operations/_patch.py +++ b/diracx-client/src/diracx/client/_generated/aio/operations/_patch.py @@ -12,17 +12,12 @@ __all__ = [ "AuthOperations", "JobsOperations", + "PilotsOperations", ] # Add all objects you want publicly available to users at this package level from ....patches.auth.aio import AuthOperations from ....patches.jobs.aio import JobsOperations - -try: - from ....patches.pilots.aio import PilotsOperations - - __all__.append("PilotsOperations") -except ImportError: - pass +from ....patches.pilots.aio import PilotsOperations def patch_sdk(): diff --git a/diracx-client/src/diracx/client/_generated/models/__init__.py b/diracx-client/src/diracx/client/_generated/models/__init__.py index 8e1dbe20d..49b841e05 100644 --- a/diracx-client/src/diracx/client/_generated/models/__init__.py +++ b/diracx-client/src/diracx/client/_generated/models/__init__.py @@ -16,8 +16,8 @@ BodyAuthRevokeRefreshTokenByRefreshToken, BodyJobsRescheduleJobs, BodyJobsUnassignBulkJobsSandboxes, - BodyPilotsAddPilotStamps, - BodyPilotsUpdatePilotFields, + BodyPilotsRegisterPilots, + BodyPilotsUpdatePilotMetadata, GroupInfo, HTTPValidationError, HeartbeatData, @@ -28,7 +28,7 @@ JobStatusUpdate, Metadata, OpenIDConfiguration, - PilotFieldsMapping, + PilotMetadata, SandboxDownloadResponse, SandboxInfo, SandboxUploadResponse, @@ -67,8 +67,8 @@ "BodyAuthRevokeRefreshTokenByRefreshToken", "BodyJobsRescheduleJobs", "BodyJobsUnassignBulkJobsSandboxes", - "BodyPilotsAddPilotStamps", - "BodyPilotsUpdatePilotFields", + "BodyPilotsRegisterPilots", + "BodyPilotsUpdatePilotMetadata", "GroupInfo", "HTTPValidationError", "HeartbeatData", @@ -79,7 +79,7 @@ "JobStatusUpdate", "Metadata", "OpenIDConfiguration", - "PilotFieldsMapping", + "PilotMetadata", "SandboxDownloadResponse", "SandboxInfo", "SandboxUploadResponse", diff --git a/diracx-client/src/diracx/client/_generated/models/_models.py b/diracx-client/src/diracx/client/_generated/models/_models.py index a6306e67e..1b71411cf 100644 --- a/diracx-client/src/diracx/client/_generated/models/_models.py +++ b/diracx-client/src/diracx/client/_generated/models/_models.py @@ -184,12 +184,12 @@ def __init__(self, *, job_ids: list[int], **kwargs: Any) -> None: self.job_ids = job_ids -class BodyPilotsAddPilotStamps(_serialization.Model): - """Body_pilots_add_pilot_stamps. +class BodyPilotsRegisterPilots(_serialization.Model): + """Body_pilots_register_pilots. All required parameters must be populated in order to send to server. - :ivar pilot_stamps: List of the pilot stamps we want to add to the db. Required. + :ivar pilot_stamps: Stamps of the pilots to create. Required. :vartype pilot_stamps: list[str] :ivar vo: Pilot virtual organization. Required. :vartype vo: str @@ -201,8 +201,8 @@ class BodyPilotsAddPilotStamps(_serialization.Model): :vartype destination_site: str :ivar pilot_references: Association of a pilot reference with a pilot stamp. :vartype pilot_references: dict[str, str] - :ivar pilot_status: Status of the pilots. Known values are: "Submitted", "Waiting", "Running", - "Done", "Failed", "Deleted", "Aborted", and "Unknown". + :ivar pilot_status: Initial status of the pilots. Known values are: "Submitted", "Waiting", + "Running", "Done", "Failed", "Deleted", "Aborted", and "Unknown". :vartype pilot_status: str or ~_generated.models.PilotStatus """ @@ -234,7 +234,7 @@ def __init__( **kwargs: Any ) -> None: """ - :keyword pilot_stamps: List of the pilot stamps we want to add to the db. Required. + :keyword pilot_stamps: Stamps of the pilots to create. Required. :paramtype pilot_stamps: list[str] :keyword vo: Pilot virtual organization. Required. :paramtype vo: str @@ -246,7 +246,7 @@ def __init__( :paramtype destination_site: str :keyword pilot_references: Association of a pilot reference with a pilot stamp. :paramtype pilot_references: dict[str, str] - :keyword pilot_status: Status of the pilots. Known values are: "Submitted", "Waiting", + :keyword pilot_status: Initial status of the pilots. Known values are: "Submitted", "Waiting", "Running", "Done", "Failed", "Deleted", "Aborted", and "Unknown". :paramtype pilot_status: str or ~_generated.models.PilotStatus """ @@ -260,31 +260,30 @@ def __init__( self.pilot_status = pilot_status -class BodyPilotsUpdatePilotFields(_serialization.Model): - """Body_pilots_update_pilot_fields. +class BodyPilotsUpdatePilotMetadata(_serialization.Model): + """Body_pilots_update_pilot_metadata. All required parameters must be populated in order to send to server. - :ivar pilot_stamps_to_fields_mapping: (pilot_stamp, pilot_fields) mapping to change. Required. - :vartype pilot_stamps_to_fields_mapping: list[~_generated.models.PilotFieldsMapping] + :ivar pilot_metadata: Pilot metadata mappings to apply. Required. + :vartype pilot_metadata: list[~_generated.models.PilotMetadata] """ _validation = { - "pilot_stamps_to_fields_mapping": {"required": True}, + "pilot_metadata": {"required": True}, } _attribute_map = { - "pilot_stamps_to_fields_mapping": {"key": "pilot_stamps_to_fields_mapping", "type": "[PilotFieldsMapping]"}, + "pilot_metadata": {"key": "pilot_metadata", "type": "[PilotMetadata]"}, } - def __init__(self, *, pilot_stamps_to_fields_mapping: list["_models.PilotFieldsMapping"], **kwargs: Any) -> None: + def __init__(self, *, pilot_metadata: list["_models.PilotMetadata"], **kwargs: Any) -> None: """ - :keyword pilot_stamps_to_fields_mapping: (pilot_stamp, pilot_fields) mapping to change. - Required. - :paramtype pilot_stamps_to_fields_mapping: list[~_generated.models.PilotFieldsMapping] + :keyword pilot_metadata: Pilot metadata mappings to apply. Required. + :paramtype pilot_metadata: list[~_generated.models.PilotMetadata] """ super().__init__(**kwargs) - self.pilot_stamps_to_fields_mapping = pilot_stamps_to_fields_mapping + self.pilot_metadata = pilot_metadata class GroupInfo(_serialization.Model): @@ -1032,31 +1031,34 @@ def __init__( self.code_challenge_methods_supported = code_challenge_methods_supported -class PilotFieldsMapping(_serialization.Model): - """All the fields that a user can modify on a Pilot (except PilotStamp). +class PilotMetadata(_serialization.Model): + """Mutable metadata attached to a pilot. + + ``PilotStamp`` identifies the pilot and cannot be changed. Every other + field is optional; when absent it is left untouched by an update. All required parameters must be populated in order to send to server. - :ivar pilot_stamp: Pilotstamp. Required. + :ivar pilot_stamp: Immutable stamp identifying the pilot. Required. :vartype pilot_stamp: str - :ivar status_reason: Statusreason. + :ivar status_reason: Human-readable reason for the current status. :vartype status_reason: str - :ivar status: PilotStatus. Known values are: "Submitted", "Waiting", "Running", "Done", - "Failed", "Deleted", "Aborted", and "Unknown". + :ivar status: Current pilot status. Known values are: "Submitted", "Waiting", "Running", + "Done", "Failed", "Deleted", "Aborted", and "Unknown". :vartype status: str or ~_generated.models.PilotStatus - :ivar bench_mark: Benchmark. + :ivar bench_mark: Pilot benchmark value. :vartype bench_mark: float - :ivar destination_site: Destinationsite. + :ivar destination_site: Destination site. :vartype destination_site: str - :ivar queue: Queue. + :ivar queue: Batch queue name. :vartype queue: str - :ivar grid_site: Gridsite. + :ivar grid_site: Grid site. :vartype grid_site: str - :ivar grid_type: Gridtype. + :ivar grid_type: Grid type. :vartype grid_type: str - :ivar accounting_sent: Accountingsent. + :ivar accounting_sent: Whether accounting has been sent for this pilot. :vartype accounting_sent: bool - :ivar current_job_id: Currentjobid. + :ivar current_job_id: ID of the job currently running on this pilot. :vartype current_job_id: int """ @@ -1093,26 +1095,26 @@ def __init__( **kwargs: Any ) -> None: """ - :keyword pilot_stamp: Pilotstamp. Required. + :keyword pilot_stamp: Immutable stamp identifying the pilot. Required. :paramtype pilot_stamp: str - :keyword status_reason: Statusreason. + :keyword status_reason: Human-readable reason for the current status. :paramtype status_reason: str - :keyword status: PilotStatus. Known values are: "Submitted", "Waiting", "Running", "Done", - "Failed", "Deleted", "Aborted", and "Unknown". + :keyword status: Current pilot status. Known values are: "Submitted", "Waiting", "Running", + "Done", "Failed", "Deleted", "Aborted", and "Unknown". :paramtype status: str or ~_generated.models.PilotStatus - :keyword bench_mark: Benchmark. + :keyword bench_mark: Pilot benchmark value. :paramtype bench_mark: float - :keyword destination_site: Destinationsite. + :keyword destination_site: Destination site. :paramtype destination_site: str - :keyword queue: Queue. + :keyword queue: Batch queue name. :paramtype queue: str - :keyword grid_site: Gridsite. + :keyword grid_site: Grid site. :paramtype grid_site: str - :keyword grid_type: Gridtype. + :keyword grid_type: Grid type. :paramtype grid_type: str - :keyword accounting_sent: Accountingsent. + :keyword accounting_sent: Whether accounting has been sent for this pilot. :paramtype accounting_sent: bool - :keyword current_job_id: Currentjobid. + :keyword current_job_id: ID of the job currently running on this pilot. :paramtype current_job_id: int """ super().__init__(**kwargs) diff --git a/diracx-client/src/diracx/client/_generated/operations/_operations.py b/diracx-client/src/diracx/client/_generated/operations/_operations.py index f36059331..ed74b3610 100644 --- a/diracx-client/src/diracx/client/_generated/operations/_operations.py +++ b/diracx-client/src/diracx/client/_generated/operations/_operations.py @@ -565,7 +565,7 @@ def build_jobs_submit_jdl_jobs_request(**kwargs: Any) -> HttpRequest: return HttpRequest(method="POST", url=_url, headers=_headers, **kwargs) -def build_pilots_add_pilot_stamps_request(**kwargs: Any) -> HttpRequest: +def build_pilots_register_pilots_request(**kwargs: Any) -> HttpRequest: _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) @@ -582,30 +582,19 @@ def build_pilots_add_pilot_stamps_request(**kwargs: Any) -> HttpRequest: return HttpRequest(method="POST", url=_url, headers=_headers, **kwargs) -def build_pilots_delete_pilots_request( - *, - pilot_stamps: Optional[list[str]] = None, - age_in_days: Optional[int] = None, - delete_only_aborted: bool = False, - **kwargs: Any -) -> HttpRequest: +def build_pilots_delete_pilots_request(*, pilot_stamps: list[str], **kwargs: Any) -> HttpRequest: _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) # Construct URL _url = "/api/pilots/" # Construct parameters - if pilot_stamps is not None: - _params["pilot_stamps"] = _SERIALIZER.query("pilot_stamps", pilot_stamps, "[str]") - if age_in_days is not None: - _params["age_in_days"] = _SERIALIZER.query("age_in_days", age_in_days, "int") - if delete_only_aborted is not None: - _params["delete_only_aborted"] = _SERIALIZER.query("delete_only_aborted", delete_only_aborted, "bool") + _params["pilot_stamps"] = _SERIALIZER.query("pilot_stamps", pilot_stamps, "[str]") return HttpRequest(method="DELETE", url=_url, params=_params, **kwargs) -def build_pilots_update_pilot_fields_request(**kwargs: Any) -> HttpRequest: +def build_pilots_update_pilot_metadata_request(**kwargs: Any) -> HttpRequest: # pylint: disable=name-too-long _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) @@ -619,29 +608,6 @@ def build_pilots_update_pilot_fields_request(**kwargs: Any) -> HttpRequest: return HttpRequest(method="PATCH", url=_url, headers=_headers, **kwargs) -def build_pilots_get_pilot_jobs_request( - *, pilot_stamp: Optional[str] = None, job_id: Optional[int] = None, **kwargs: Any -) -> HttpRequest: - _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) - _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) - - accept = _headers.pop("Accept", "application/json") - - # Construct URL - _url = "/api/pilots/jobs" - - # Construct parameters - if pilot_stamp is not None: - _params["pilot_stamp"] = _SERIALIZER.query("pilot_stamp", pilot_stamp, "str") - if job_id is not None: - _params["job_id"] = _SERIALIZER.query("job_id", job_id, "int") - - # Construct headers - _headers["Accept"] = _SERIALIZER.header("accept", accept, "str") - - return HttpRequest(method="GET", url=_url, params=_params, headers=_headers, **kwargs) - - def build_pilots_search_request(*, page: int = 1, per_page: int = 100, **kwargs: Any) -> HttpRequest: _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) @@ -2572,6 +2538,12 @@ def search( By default, the search will return all jobs the user has access to, and all the fields of the job will be returned. + A ``PilotStamp`` pseudo-parameter is also accepted in the ``search`` + filter list (operators ``eq`` / ``in`` only): it is transparently + resolved through ``JobToPilotMapping`` into a ``JobID`` filter, + allowing callers to ask "jobs run by this pilot" through the same + endpoint. + :param body: Default value is None. :type body: ~_generated.models.SearchParams :keyword page: Default value is 1. @@ -2612,6 +2584,12 @@ def search( By default, the search will return all jobs the user has access to, and all the fields of the job will be returned. + A ``PilotStamp`` pseudo-parameter is also accepted in the ``search`` + filter list (operators ``eq`` / ``in`` only): it is transparently + resolved through ``JobToPilotMapping`` into a ``JobID`` filter, + allowing callers to ask "jobs run by this pilot" through the same + endpoint. + :param body: Default value is None. :type body: IO[bytes] :keyword page: Default value is 1. @@ -2651,6 +2629,12 @@ def search( By default, the search will return all jobs the user has access to, and all the fields of the job will be returned. + A ``PilotStamp`` pseudo-parameter is also accepted in the ``search`` + filter list (operators ``eq`` / ``in`` only): it is transparently + resolved through ``JobToPilotMapping`` into a ``JobID`` filter, + allowing callers to ask "jobs run by this pilot" through the same + endpoint. + :param body: Is either a SearchParams type or a IO[bytes] type. Default value is None. :type body: ~_generated.models.SearchParams or IO[bytes] :keyword page: Default value is 1. @@ -2958,17 +2942,17 @@ def __init__(self, *args, **kwargs) -> None: self._deserialize: Deserializer = input_args.pop(0) if input_args else kwargs.pop("deserializer") @overload - def add_pilot_stamps( - self, body: _models.BodyPilotsAddPilotStamps, *, content_type: str = "application/json", **kwargs: Any + def register_pilots( + self, body: _models.BodyPilotsRegisterPilots, *, content_type: str = "application/json", **kwargs: Any ) -> Any: - """Add Pilot Stamps. + """Register Pilots. - Endpoint where a you can create pilots with their references. + Register a batch of pilots with their references. - If a pilot stamp already exists, it will block the insertion. + If any stamp already exists, the whole batch is rejected with a 409. :param body: Required. - :type body: ~_generated.models.BodyPilotsAddPilotStamps + :type body: ~_generated.models.BodyPilotsRegisterPilots :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. Default value is "application/json". :paramtype content_type: str @@ -2978,12 +2962,12 @@ def add_pilot_stamps( """ @overload - def add_pilot_stamps(self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any) -> Any: - """Add Pilot Stamps. + def register_pilots(self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any) -> Any: + """Register Pilots. - Endpoint where a you can create pilots with their references. + Register a batch of pilots with their references. - If a pilot stamp already exists, it will block the insertion. + If any stamp already exists, the whole batch is rejected with a 409. :param body: Required. :type body: IO[bytes] @@ -2996,15 +2980,15 @@ def add_pilot_stamps(self, body: IO[bytes], *, content_type: str = "application/ """ @distributed_trace - def add_pilot_stamps(self, body: Union[_models.BodyPilotsAddPilotStamps, IO[bytes]], **kwargs: Any) -> Any: - """Add Pilot Stamps. + def register_pilots(self, body: Union[_models.BodyPilotsRegisterPilots, IO[bytes]], **kwargs: Any) -> Any: + """Register Pilots. - Endpoint where a you can create pilots with their references. + Register a batch of pilots with their references. - If a pilot stamp already exists, it will block the insertion. + If any stamp already exists, the whole batch is rejected with a 409. - :param body: Is either a BodyPilotsAddPilotStamps type or a IO[bytes] type. Required. - :type body: ~_generated.models.BodyPilotsAddPilotStamps or IO[bytes] + :param body: Is either a BodyPilotsRegisterPilots type or a IO[bytes] type. Required. + :type body: ~_generated.models.BodyPilotsRegisterPilots or IO[bytes] :return: any :rtype: any :raises ~azure.core.exceptions.HttpResponseError: @@ -3029,9 +3013,9 @@ def add_pilot_stamps(self, body: Union[_models.BodyPilotsAddPilotStamps, IO[byte if isinstance(body, (IOBase, bytes)): _content = body else: - _json = self._serialize.body(body, "BodyPilotsAddPilotStamps") + _json = self._serialize.body(body, "BodyPilotsRegisterPilots") - _request = build_pilots_add_pilot_stamps_request( + _request = build_pilots_register_pilots_request( content_type=content_type, json=_json, content=_content, @@ -3060,35 +3044,20 @@ def add_pilot_stamps(self, body: Union[_models.BodyPilotsAddPilotStamps, IO[byte @distributed_trace def delete_pilots( # pylint: disable=inconsistent-return-statements - self, - *, - pilot_stamps: Optional[list[str]] = None, - age_in_days: Optional[int] = None, - delete_only_aborted: bool = False, - **kwargs: Any + self, *, pilot_stamps: list[str], **kwargs: Any ) -> None: """Delete Pilots. - Endpoint to delete a pilot. - - Two features: + Delete pilots by stamp. + Deletes the pilot rows as well as their logs and job associations. - #. Or you provide pilot_stamps, so you can delete pilots by their stamp - #. Or you provide age_in_days, so you can delete pilots that lived more than age_in_days days. + Age-based retention cleanup is deliberately *not* exposed here: it is + handled by the maintenance task worker. See + ``diracx.logic.pilots.management.delete_pilots``. - Note: If you delete a pilot, its logs and its associations with jobs WILL be deleted. - - :keyword pilot_stamps: Stamps of the pilots we want to delete. Default value is None. + :keyword pilot_stamps: Stamps of the pilots to delete. Required. :paramtype pilot_stamps: list[str] - :keyword age_in_days: The number of days that define the maximum age of pilots to be - deleted.Pilots older than this age will be considered for deletion. Default value is None. - :paramtype age_in_days: int - :keyword delete_only_aborted: Flag indicating whether to only delete pilots whose status is - 'Aborted'.If set to True, only pilots with the 'Aborted' status will be deleted.It is set by - default as True to avoid any mistake.This flag is only used for deletion by time. Default value - is False. - :paramtype delete_only_aborted: bool :return: None :rtype: None :raises ~azure.core.exceptions.HttpResponseError: @@ -3108,8 +3077,6 @@ def delete_pilots( # pylint: disable=inconsistent-return-statements _request = build_pilots_delete_pilots_request( pilot_stamps=pilot_stamps, - age_in_days=age_in_days, - delete_only_aborted=delete_only_aborted, headers=_headers, params=_params, ) @@ -3130,17 +3097,18 @@ def delete_pilots( # pylint: disable=inconsistent-return-statements return cls(pipeline_response, None, {}) # type: ignore @overload - def update_pilot_fields( - self, body: _models.BodyPilotsUpdatePilotFields, *, content_type: str = "application/json", **kwargs: Any + def update_pilot_metadata( + self, body: _models.BodyPilotsUpdatePilotMetadata, *, content_type: str = "application/json", **kwargs: Any ) -> None: - """Update Pilot Fields. + """Update Pilot Metadata. - Modify a field of a pilot. + Update pilot metadata (status, benchmark, etc.). - Note: Only the fields in PilotFieldsMapping are mutable, except for the PilotStamp. + Only fields defined in ``PilotMetadata`` are mutable. ``PilotStamp`` + identifies the row and cannot be changed. :param body: Required. - :type body: ~_generated.models.BodyPilotsUpdatePilotFields + :type body: ~_generated.models.BodyPilotsUpdatePilotMetadata :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. Default value is "application/json". :paramtype content_type: str @@ -3150,12 +3118,13 @@ def update_pilot_fields( """ @overload - def update_pilot_fields(self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any) -> None: - """Update Pilot Fields. + def update_pilot_metadata(self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any) -> None: + """Update Pilot Metadata. - Modify a field of a pilot. + Update pilot metadata (status, benchmark, etc.). - Note: Only the fields in PilotFieldsMapping are mutable, except for the PilotStamp. + Only fields defined in ``PilotMetadata`` are mutable. ``PilotStamp`` + identifies the row and cannot be changed. :param body: Required. :type body: IO[bytes] @@ -3168,17 +3137,18 @@ def update_pilot_fields(self, body: IO[bytes], *, content_type: str = "applicati """ @distributed_trace - def update_pilot_fields( # pylint: disable=inconsistent-return-statements - self, body: Union[_models.BodyPilotsUpdatePilotFields, IO[bytes]], **kwargs: Any + def update_pilot_metadata( # pylint: disable=inconsistent-return-statements + self, body: Union[_models.BodyPilotsUpdatePilotMetadata, IO[bytes]], **kwargs: Any ) -> None: - """Update Pilot Fields. + """Update Pilot Metadata. - Modify a field of a pilot. + Update pilot metadata (status, benchmark, etc.). - Note: Only the fields in PilotFieldsMapping are mutable, except for the PilotStamp. + Only fields defined in ``PilotMetadata`` are mutable. ``PilotStamp`` + identifies the row and cannot be changed. - :param body: Is either a BodyPilotsUpdatePilotFields type or a IO[bytes] type. Required. - :type body: ~_generated.models.BodyPilotsUpdatePilotFields or IO[bytes] + :param body: Is either a BodyPilotsUpdatePilotMetadata type or a IO[bytes] type. Required. + :type body: ~_generated.models.BodyPilotsUpdatePilotMetadata or IO[bytes] :return: None :rtype: None :raises ~azure.core.exceptions.HttpResponseError: @@ -3203,9 +3173,9 @@ def update_pilot_fields( # pylint: disable=inconsistent-return-statements if isinstance(body, (IOBase, bytes)): _content = body else: - _json = self._serialize.body(body, "BodyPilotsUpdatePilotFields") + _json = self._serialize.body(body, "BodyPilotsUpdatePilotMetadata") - _request = build_pilots_update_pilot_fields_request( + _request = build_pilots_update_pilot_metadata_request( content_type=content_type, json=_json, content=_content, @@ -3228,61 +3198,6 @@ def update_pilot_fields( # pylint: disable=inconsistent-return-statements if cls: return cls(pipeline_response, None, {}) # type: ignore - @distributed_trace - def get_pilot_jobs( - self, *, pilot_stamp: Optional[str] = None, job_id: Optional[int] = None, **kwargs: Any - ) -> list[int]: - """Get Pilot Jobs. - - Endpoint only for admins, to get jobs of a pilot. - - :keyword pilot_stamp: The stamp of the pilot. Default value is None. - :paramtype pilot_stamp: str - :keyword job_id: The ID of the job. Default value is None. - :paramtype job_id: int - :return: list of int - :rtype: list[int] - :raises ~azure.core.exceptions.HttpResponseError: - """ - error_map: MutableMapping = { - 401: ClientAuthenticationError, - 404: ResourceNotFoundError, - 409: ResourceExistsError, - 304: ResourceNotModifiedError, - } - error_map.update(kwargs.pop("error_map", {}) or {}) - - _headers = kwargs.pop("headers", {}) or {} - _params = kwargs.pop("params", {}) or {} - - cls: ClsType[list[int]] = kwargs.pop("cls", None) - - _request = build_pilots_get_pilot_jobs_request( - pilot_stamp=pilot_stamp, - job_id=job_id, - headers=_headers, - params=_params, - ) - _request.url = self._client.format_url(_request.url) - - _stream = False - pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access - _request, stream=_stream, **kwargs - ) - - response = pipeline_response.http_response - - if response.status_code not in [200]: - map_error(status_code=response.status_code, response=response, error_map=error_map) - raise HttpResponseError(response=response) - - deserialized = self._deserialize("[int]", pipeline_response.http_response) - - if cls: - return cls(pipeline_response, deserialized, {}) # type: ignore - - return deserialized # type: ignore - @overload def search( self, @@ -3297,6 +3212,14 @@ def search( Retrieve information about pilots. + Normal users see only their own VO's pilots. Service administrators see + pilots from all VOs. + + A ``JobID`` pseudo-parameter is also accepted in the ``search`` filter + list (operators ``eq`` / ``in`` only): it is transparently resolved + through ``JobToPilotMapping`` into a ``PilotID`` filter, allowing + callers to ask "pilots that ran this job" through the same endpoint. + :param body: Default value is None. :type body: ~_generated.models.SearchParams :keyword page: Default value is 1. @@ -3325,6 +3248,14 @@ def search( Retrieve information about pilots. + Normal users see only their own VO's pilots. Service administrators see + pilots from all VOs. + + A ``JobID`` pseudo-parameter is also accepted in the ``search`` filter + list (operators ``eq`` / ``in`` only): it is transparently resolved + through ``JobToPilotMapping`` into a ``PilotID`` filter, allowing + callers to ask "pilots that ran this job" through the same endpoint. + :param body: Default value is None. :type body: IO[bytes] :keyword page: Default value is 1. @@ -3352,6 +3283,14 @@ def search( Retrieve information about pilots. + Normal users see only their own VO's pilots. Service administrators see + pilots from all VOs. + + A ``JobID`` pseudo-parameter is also accepted in the ``search`` filter + list (operators ``eq`` / ``in`` only): it is transparently resolved + through ``JobToPilotMapping`` into a ``PilotID`` filter, allowing + callers to ask "pilots that ran this job" through the same endpoint. + :param body: Is either a SearchParams type or a IO[bytes] type. Default value is None. :type body: ~_generated.models.SearchParams or IO[bytes] :keyword page: Default value is 1. @@ -3425,7 +3364,10 @@ def search( def summary(self, body: _models.SummaryParams, *, content_type: str = "application/json", **kwargs: Any) -> Any: """Summary. - Show information suitable for plotting. + Aggregate pilot counts suitable for plotting. + + Normal users see only their own VO's pilots. Service administrators see + pilots from all VOs. :param body: Required. :type body: ~_generated.models.SummaryParams @@ -3441,7 +3383,10 @@ def summary(self, body: _models.SummaryParams, *, content_type: str = "applicati def summary(self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any) -> Any: """Summary. - Show information suitable for plotting. + Aggregate pilot counts suitable for plotting. + + Normal users see only their own VO's pilots. Service administrators see + pilots from all VOs. :param body: Required. :type body: IO[bytes] @@ -3457,7 +3402,10 @@ def summary(self, body: IO[bytes], *, content_type: str = "application/json", ** def summary(self, body: Union[_models.SummaryParams, IO[bytes]], **kwargs: Any) -> Any: """Summary. - Show information suitable for plotting. + Aggregate pilot counts suitable for plotting. + + Normal users see only their own VO's pilots. Service administrators see + pilots from all VOs. :param body: Is either a SummaryParams type or a IO[bytes] type. Required. :type body: ~_generated.models.SummaryParams or IO[bytes] diff --git a/diracx-client/src/diracx/client/_generated/operations/_patch.py b/diracx-client/src/diracx/client/_generated/operations/_patch.py index 043341965..dd8ba9d0b 100644 --- a/diracx-client/src/diracx/client/_generated/operations/_patch.py +++ b/diracx-client/src/diracx/client/_generated/operations/_patch.py @@ -12,17 +12,12 @@ __all__ = [ "AuthOperations", "JobsOperations", + "PilotsOperations", ] # Add all objects you want publicly available to users at this package level from ...patches.auth.sync import AuthOperations from ...patches.jobs.sync import JobsOperations - -try: - from ...patches.pilots.sync import PilotsOperations - - __all__.append("PilotsOperations") -except ImportError: - pass +from ...patches.pilots.sync import PilotsOperations def patch_sdk(): diff --git a/diracx-client/src/diracx/client/patches/pilots/aio.py b/diracx-client/src/diracx/client/patches/pilots/aio.py index ac533a67c..95d1d5f5d 100644 --- a/diracx-client/src/diracx/client/patches/pilots/aio.py +++ b/diracx-client/src/diracx/client/patches/pilots/aio.py @@ -1,9 +1,4 @@ -"""Patches for the autorest-generated pilots client. - -This file can be used to customize the generated code for the pilots client. -When adding new classes to this file, make sure to also add them to the -__all__ list in the corresponding file in the patches directory. -""" +"""Patches for the autorest-generated async pilots client.""" from __future__ import annotations @@ -15,16 +10,16 @@ from azure.core.tracing.decorator_async import distributed_trace_async -from ..._generated.aio.operations._operations import PilotsOperations as _PilotsOperations +from ..._generated.aio.operations._operations import ( + PilotsOperations as _PilotsOperations, +) from .common import ( - make_search_body, - make_summary_body, - make_add_pilot_stamps_body, - make_update_pilot_fields_body, + RegisterPilotsKwargs, SearchKwargs, SummaryKwargs, - AddPilotStampsKwargs, - UpdatePilotFieldsKwargs + make_register_pilots_body, + make_search_body, + make_summary_body, ) # We're intentionally ignoring overrides here because we want to change the interface. @@ -33,21 +28,22 @@ class PilotsOperations(_PilotsOperations): @distributed_trace_async - async def search(self, **kwargs: Unpack[SearchKwargs]) -> list[dict[str, Any]]: - """TODO""" + async def search( + self, **kwargs: Unpack[SearchKwargs] + ) -> list[dict[str, Any]]: + """Search for pilots matching the provided filters.""" return await super().search(**make_search_body(**kwargs)) @distributed_trace_async - async def summary(self, **kwargs: Unpack[SummaryKwargs]) -> list[dict[str, Any]]: - """TODO""" + async def summary( + self, **kwargs: Unpack[SummaryKwargs] + ) -> list[dict[str, Any]]: + """Return pilot counts aggregated by the requested columns.""" return await super().summary(**make_summary_body(**kwargs)) @distributed_trace_async - async def add_pilot_stamps(self, **kwargs: Unpack[AddPilotStampsKwargs]) -> None: - """TODO""" - return await super().add_pilot_stamps(**make_add_pilot_stamps_body(**kwargs)) - - @distributed_trace_async - async def update_pilot_fields(self, **kwargs: Unpack[UpdatePilotFieldsKwargs]) -> None: - """TODO""" - return await super().update_pilot_fields(**make_update_pilot_fields_body(**kwargs)) + async def register_pilots( + self, **kwargs: Unpack[RegisterPilotsKwargs] + ) -> None: + """Register a batch of pilots.""" + return await super().register_pilots(**make_register_pilots_body(**kwargs)) diff --git a/diracx-client/src/diracx/client/patches/pilots/common.py b/diracx-client/src/diracx/client/patches/pilots/common.py index fd54786f2..80e194c57 100644 --- a/diracx-client/src/diracx/client/patches/pilots/common.py +++ b/diracx-client/src/diracx/client/patches/pilots/common.py @@ -1,23 +1,21 @@ -"""Utilities which are common to the sync and async pilots operator patches.""" +"""Utilities shared by the sync and async pilots operator patches.""" from __future__ import annotations __all__ = [ - "make_search_body", "SearchKwargs", - "make_summary_body", + "make_search_body", "SummaryKwargs", - "AddPilotStampsKwargs", - "make_add_pilot_stamps_body", - "UpdatePilotFieldsKwargs", - "make_update_pilot_fields_body" + "make_summary_body", + "RegisterPilotsKwargs", + "make_register_pilots_body", ] import json from io import BytesIO -from typing import Any, IO, TypedDict, Unpack, cast, Literal +from typing import IO, Any, Literal, TypedDict, Unpack, cast -from diracx.core.models.pilot import PilotFieldsMapping, PilotStatus +from diracx.core.models.pilot import PilotStatus from diracx.core.models.search import SearchSpec @@ -29,6 +27,8 @@ class ResponseExtra(TypedDict, total=False): # ------------------ Search ------------------ + + class SearchBody(TypedDict, total=False): parameters: list[str] | None search: list[SearchSpec] | None @@ -44,8 +44,8 @@ class SearchKwargs(SearchBody, SearchExtra): ... class UnderlyingSearchArgs(ResponseExtra, total=False): - # FIXME: The autorest-generated has a bug that it expected IO[bytes] despite - # the code being generated to support IO[bytes] | bytes. + # FIXME: The autorest-generated operation expects IO[bytes] despite its + # signature advertising IO[bytes] | bytes. body: IO[bytes] @@ -62,19 +62,21 @@ def make_search_body(**kwargs: Unpack[SearchKwargs]) -> UnderlyingSearchArgs: result.update(cast(SearchExtra, kwargs)) return result + # ------------------ Summary ------------------ + class SummaryBody(TypedDict, total=False): grouping: list[str] - search: list[str] + search: list[SearchSpec] class SummaryKwargs(SummaryBody, ResponseExtra): ... class UnderlyingSummaryArgs(ResponseExtra, total=False): - # FIXME: The autorest-generated has a bug that it expected IO[bytes] despite - # the code being generated to support IO[bytes] | bytes. + # FIXME: The autorest-generated operation expects IO[bytes] despite its + # signature advertising IO[bytes] | bytes. body: IO[bytes] @@ -91,57 +93,53 @@ def make_summary_body(**kwargs: Unpack[SummaryKwargs]) -> UnderlyingSummaryArgs: result.update(cast(ResponseExtra, kwargs)) return result -# ------------------ AddPilotStamps ------------------ -class AddPilotStampsBody(TypedDict, total=False): +# ------------------ Register pilots ------------------ + + +class RegisterPilotsBody(TypedDict, total=False): pilot_stamps: list[str] + vo: str grid_type: str grid_site: str + destination_site: str pilot_references: dict[str, str] pilot_status: PilotStatus - vo: str -class AddPilotStampsKwargs(AddPilotStampsBody, ResponseExtra): ... -class UnderlyingAddPilotStampsArgs(ResponseExtra, total=False): - # FIXME: The autorest-generated has a bug that it expected IO[bytes] despite - # the code being generated to support IO[bytes] | bytes. - body: IO[bytes] - -def make_add_pilot_stamps_body(**kwargs: Unpack[AddPilotStampsKwargs]) -> UnderlyingAddPilotStampsArgs: - body: AddPilotStampsBody = {} - for key in AddPilotStampsBody.__optional_keys__: - if key not in kwargs: - continue - key = cast(Literal["pilot_stamps", "grid_type", "grid_site", "pilot_references", "pilot_status", "vo"], key) - value = kwargs.pop(key) - if value is not None: - body[key] = value - result: UnderlyingAddPilotStampsArgs = {"body": BytesIO(json.dumps(body).encode("utf-8"))} - result.update(cast(ResponseExtra, kwargs)) - return result +class RegisterPilotsKwargs(RegisterPilotsBody, ResponseExtra): ... -# ------------------ UpdatePilotFields ------------------ -class UpdatePilotFieldsBody(TypedDict, total=False): - pilot_stamps_to_fields_mapping: list[PilotFieldsMapping] - -class UpdatePilotFieldsKwargs(UpdatePilotFieldsBody, ResponseExtra): ... - -class UnderlyingUpdatePilotFields(ResponseExtra, total=False): - # FIXME: The autorest-generated has a bug that it expected IO[bytes] despite - # the code being generated to support IO[bytes] | bytes. +class UnderlyingRegisterPilotsArgs(ResponseExtra, total=False): + # FIXME: The autorest-generated operation expects IO[bytes] despite its + # signature advertising IO[bytes] | bytes. body: IO[bytes] -def make_update_pilot_fields_body(**kwargs: Unpack[UpdatePilotFieldsKwargs]) -> UnderlyingUpdatePilotFields: - body: UpdatePilotFieldsBody = {} - for key in UpdatePilotFieldsBody.__optional_keys__: + +def make_register_pilots_body( + **kwargs: Unpack[RegisterPilotsKwargs], +) -> UnderlyingRegisterPilotsArgs: + body: RegisterPilotsBody = {} + for key in RegisterPilotsBody.__optional_keys__: if key not in kwargs: continue - key = cast(Literal["pilot_stamps_to_fields_mapping"], key) + key = cast( + Literal[ + "pilot_stamps", + "vo", + "grid_type", + "grid_site", + "destination_site", + "pilot_references", + "pilot_status", + ], + key, + ) value = kwargs.pop(key) if value is not None: body[key] = value - result: UnderlyingUpdatePilotFields = {"body": BytesIO(json.dumps(body).encode("utf-8"))} + result: UnderlyingRegisterPilotsArgs = { + "body": BytesIO(json.dumps(body).encode("utf-8")) + } result.update(cast(ResponseExtra, kwargs)) return result diff --git a/diracx-client/src/diracx/client/patches/pilots/sync.py b/diracx-client/src/diracx/client/patches/pilots/sync.py index 744cee161..04c60116d 100644 --- a/diracx-client/src/diracx/client/patches/pilots/sync.py +++ b/diracx-client/src/diracx/client/patches/pilots/sync.py @@ -1,9 +1,4 @@ -"""Patches for the autorest-generated pilots client. - -This file can be used to customize the generated code for the pilots client. -When adding new classes to this file, make sure to also add them to the -__all__ list in the corresponding file in the patches directory. -""" +"""Patches for the autorest-generated sync pilots client.""" from __future__ import annotations @@ -17,14 +12,12 @@ from ..._generated.operations._operations import PilotsOperations as _PilotsOperations from .common import ( - make_search_body, - make_summary_body, - make_add_pilot_stamps_body, - make_update_pilot_fields_body, + RegisterPilotsKwargs, SearchKwargs, SummaryKwargs, - AddPilotStampsKwargs, - UpdatePilotFieldsKwargs + make_register_pilots_body, + make_search_body, + make_summary_body, ) # We're intentionally ignoring overrides here because we want to change the interface. @@ -34,20 +27,15 @@ class PilotsOperations(_PilotsOperations): @distributed_trace def search(self, **kwargs: Unpack[SearchKwargs]) -> list[dict[str, Any]]: - """TODO""" + """Search for pilots matching the provided filters.""" return super().search(**make_search_body(**kwargs)) @distributed_trace def summary(self, **kwargs: Unpack[SummaryKwargs]) -> list[dict[str, Any]]: - """TODO""" + """Return pilot counts aggregated by the requested columns.""" return super().summary(**make_summary_body(**kwargs)) @distributed_trace - def add_pilot_stamps(self, **kwargs: Unpack[AddPilotStampsKwargs]) -> None: - """TODO""" - return super().add_pilot_stamps(**make_add_pilot_stamps_body(**kwargs)) - - @distributed_trace - def update_pilot_fields(self, **kwargs: Unpack[UpdatePilotFieldsKwargs]) -> None: - """TODO""" - return super().update_pilot_fields(**make_update_pilot_fields_body(**kwargs)) + def register_pilots(self, **kwargs: Unpack[RegisterPilotsKwargs]) -> None: + """Register a batch of pilots.""" + return super().register_pilots(**make_register_pilots_body(**kwargs)) diff --git a/diracx-core/src/diracx/core/exceptions.py b/diracx-core/src/diracx/core/exceptions.py index 19d8d5a41..1ff78eff9 100644 --- a/diracx-core/src/diracx/core/exceptions.py +++ b/diracx-core/src/diracx/core/exceptions.py @@ -106,10 +106,16 @@ class NotReadyError(DiracError): class PilotNotFoundError(DiracError): """At least one pilot is not found.""" + http_status_code = HTTPStatus.NOT_FOUND + class PilotAlreadyExistsError(DiracError): - """At least one pilot already exists, we avoid collitions.""" + """At least one pilot already exists, we avoid collisions.""" + + http_status_code = HTTPStatus.CONFLICT class PilotAlreadyAssociatedWithJobError(DiracError): """We can't associate a pilot with the same job twice.""" + + http_status_code = HTTPStatus.CONFLICT diff --git a/diracx-core/src/diracx/core/models/pilot.py b/diracx-core/src/diracx/core/models/pilot.py index 7abba1378..216beb01d 100644 --- a/diracx-core/src/diracx/core/models/pilot.py +++ b/diracx-core/src/diracx/core/models/pilot.py @@ -4,7 +4,7 @@ from enum import StrEnum -from pydantic import BaseModel +from pydantic import BaseModel, Field class PilotStatus(StrEnum): @@ -18,16 +18,30 @@ class PilotStatus(StrEnum): UNKNOWN = "Unknown" -class PilotFieldsMapping(BaseModel, extra="forbid"): - """All the fields that a user can modify on a Pilot (except PilotStamp).""" - - PilotStamp: str - StatusReason: str | None = None - Status: PilotStatus | None = None - BenchMark: float | None = None - DestinationSite: str | None = None - Queue: str | None = None - GridSite: str | None = None - GridType: str | None = None - AccountingSent: bool | None = None - CurrentJobID: int | None = None +class PilotMetadata(BaseModel, extra="forbid"): + """Mutable metadata attached to a pilot. + + ``PilotStamp`` identifies the pilot and cannot be changed. Every other + field is optional; when absent it is left untouched by an update. + """ + + PilotStamp: str = Field(description="Immutable stamp identifying the pilot.") + StatusReason: str | None = Field( + default=None, description="Human-readable reason for the current status." + ) + Status: PilotStatus | None = Field( + default=None, description="Current pilot status." + ) + BenchMark: float | None = Field(default=None, description="Pilot benchmark value.") + DestinationSite: str | None = Field(default=None, description="Destination site.") + Queue: str | None = Field(default=None, description="Batch queue name.") + GridSite: str | None = Field(default=None, description="Grid site.") + GridType: str | None = Field(default=None, description="Grid type.") + AccountingSent: bool | None = Field( + default=None, + description="Whether accounting has been sent for this pilot.", + ) + CurrentJobID: int | None = Field( + default=None, + description="ID of the job currently running on this pilot.", + ) diff --git a/diracx-core/src/diracx/core/models/search.py b/diracx-core/src/diracx/core/models/search.py index d8baf2cb8..9d35e86c9 100644 --- a/diracx-core/src/diracx/core/models/search.py +++ b/diracx-core/src/diracx/core/models/search.py @@ -1,6 +1,5 @@ from __future__ import annotations -from datetime import datetime from enum import StrEnum from pydantic import BaseModel @@ -25,7 +24,7 @@ class VectorSearchOperator(StrEnum): class ScalarSearchSpec(TypedDict): parameter: str operator: ScalarSearchOperator - value: str | int | datetime + value: str | int class VectorSearchSpec(TypedDict): diff --git a/diracx-db/src/diracx/db/sql/pilots/db.py b/diracx-db/src/diracx/db/sql/pilots/db.py index e88ae9bb7..b65ede5b2 100644 --- a/diracx-db/src/diracx/db/sql/pilots/db.py +++ b/diracx-db/src/diracx/db/sql/pilots/db.py @@ -5,20 +5,18 @@ from datetime import datetime, timezone from typing import Any -from sqlalchemy import bindparam +from sqlalchemy import case, delete, insert, literal, select, update from sqlalchemy.exc import IntegrityError -from sqlalchemy.sql import delete, insert, update +from sqlalchemy.sql import expression from diracx.core.exceptions import ( PilotAlreadyAssociatedWithJobError, PilotNotFoundError, ) -from diracx.core.models.pilot import PilotFieldsMapping, PilotStatus +from diracx.core.models.pilot import PilotMetadata, PilotStatus from diracx.core.models.search import SearchSpec, SortSpec -from ..utils import ( - BaseSQLDB, -) +from ..utils import BaseSQLDB from .schema import ( JobToPilotMapping, PilotAgents, @@ -28,13 +26,11 @@ class PilotAgentsDB(BaseSQLDB): - """PilotAgentsDB class is a front-end to the PilotAgents Database.""" + """Front-end to the PilotAgents database.""" metadata = PilotAgentsDBBase.metadata - # ----------------------------- Insert Functions ----------------------------- - - async def add_pilots( + async def register_pilots( self, pilot_stamps: list[str], vo: str, @@ -44,16 +40,16 @@ async def add_pilots( pilot_references: dict[str, str] | None = None, status: str = PilotStatus.SUBMITTED, ): - """Bulk add pilots in the DB. + """Bulk-register pilots. - If we can't find a pilot_reference associated with a stamp, we take the stamp by default. + If a stamp has no entry in `pilot_references` the stamp is used as + the reference. """ if pilot_references is None: pilot_references = {} now = datetime.now(tz=timezone.utc) - # Prepare the list of dictionaries for bulk insertion values = [ { "PilotJobReference": pilot_references.get(stamp, stamp), @@ -69,125 +65,103 @@ async def add_pilots( for stamp in pilot_stamps ] - # Insert multiple rows in a single execute call and use 'returning' to get primary keys - stmt = insert(PilotAgents).values(values) # Assuming 'id' is the primary key + await self.conn.execute(insert(PilotAgents).values(values)) - await self.conn.execute(stmt) - - async def add_jobs_to_pilot(self, job_to_pilot_mapping: list[dict[str, Any]]): + async def assign_jobs_to_pilot(self, job_to_pilot_mapping: list[dict[str, Any]]): """Associate a pilot with jobs. - job_to_pilot_mapping format: - job_to_pilot_mapping = [{"PilotID": pilot_id, "JobID": job_id, "StartTime": now}] - - Raises: - - PilotNotFoundError if a pilot_id is not associated with a pilot. - - PilotAlreadyAssociatedWithJobError if the pilot is already associated with one of the given jobs. - - NotImplementedError if the integrity error is not caught. - - **Important note**: We assume that a job exists. - + Each entry has the shape `{"PilotID": ..., "JobID": ..., "StartTime": ...}`. + Raises PilotNotFoundError if any pilot is missing, and + PilotAlreadyAssociatedWithJobError on duplicates. Caller must + ensure the jobs exist. """ - # Insert multiple rows in a single execute call stmt = insert(JobToPilotMapping).values(job_to_pilot_mapping) try: await self.conn.execute(stmt) except IntegrityError as e: - if "foreign key" in str(e.orig).lower(): + msg = str(e.orig).lower() + if "foreign key" in msg: raise PilotNotFoundError( - detail="at least one of these pilots do not exist", + detail="at least one of these pilots does not exist", ) from e - - if ( - "duplicate entry" in str(e.orig).lower() - or "unique constraint" in str(e.orig).lower() - ): + if "duplicate entry" in msg or "unique constraint" in msg: raise PilotAlreadyAssociatedWithJobError( - detail="at least one of these pilots is already associated with a given job." + detail=( + "at least one of these pilots is already associated " + "with a given job." + ) ) from e - - # Other errors to catch - raise NotImplementedError( - "Engine Specific error not caught" + str(e) - ) from e - - # ----------------------------- Delete Functions ----------------------------- + raise async def delete_pilots(self, pilot_ids: list[int]): - """Destructive function. Delete pilots.""" - stmt = delete(PilotAgents).where(PilotAgents.pilot_id.in_(pilot_ids)) - - await self.conn.execute(stmt) + """Destructive. Delete pilots by ID.""" + await self.conn.execute( + delete(PilotAgents).where(PilotAgents.pilot_id.in_(pilot_ids)) + ) async def remove_jobs_from_pilots(self, pilot_ids: list[int]): - """Destructive function. De-associate jobs and pilots.""" - stmt = delete(JobToPilotMapping).where( - JobToPilotMapping.pilot_id.in_(pilot_ids) + """Destructive. De-associate jobs and pilots.""" + await self.conn.execute( + delete(JobToPilotMapping).where(JobToPilotMapping.pilot_id.in_(pilot_ids)) ) - await self.conn.execute(stmt) - async def delete_pilot_logs(self, pilot_ids: list[int]): - """Destructive function. Remove logs from pilots.""" - stmt = delete(PilotOutput).where(PilotOutput.pilot_id.in_(pilot_ids)) - - await self.conn.execute(stmt) + """Destructive. Remove pilot logs.""" + await self.conn.execute( + delete(PilotOutput).where(PilotOutput.pilot_id.in_(pilot_ids)) + ) - # ----------------------------- Update Functions ----------------------------- + async def update_pilot_metadata(self, pilot_metadata: list[PilotMetadata]): + """Bulk-update pilot metadata. - async def update_pilot_fields( - self, pilot_stamps_to_fields_mapping: list[PilotFieldsMapping] - ): - """Bulk update pilots with a mapping. - - pilot_stamps_to_fields_mapping format: - - [ - { - "PilotStamp": pilot_stamp, - "BenchMark": bench_mark, - "StatusReason": pilot_reason, - "AccountingSent": accounting_sent, - "Status": status, - "CurrentJobID": current_job_id, - "Queue": queue, - ... - } - ] - - The mapping helps to update multiple fields at a time. - - Raises PilotNotFoundError if one of the pilots is not found. + Each PilotMetadata entry may set a different subset of fields; + unset fields (None) are preserved. Uses a per-column CASE + expression to support heterogeneous updates, matching the pattern + in JobDB.set_job_attributes. Raises PilotNotFoundError if any of + the pilot stamps is not found. """ + if not pilot_metadata: + return + + updates_by_stamp: dict[str, dict[str, Any]] = { + m.PilotStamp: m.model_dump(exclude={"PilotStamp"}, exclude_none=True) + for m in pilot_metadata + } + + columns = {col for fields in updates_by_stamp.values() for col in fields} + if not columns: + return + + case_expressions = { + column: case( + *[ + ( + PilotAgents.__table__.c.PilotStamp == stamp, + literal( + fields[column], + type_=PilotAgents.__table__.c[column].type, + ) + if not isinstance(fields[column], expression.FunctionElement) + else fields[column], + ) + for stamp, fields in updates_by_stamp.items() + if column in fields + ], + else_=getattr(PilotAgents.__table__.c, column), + ) + for column in columns + } + stmt = ( update(PilotAgents) - .where(PilotAgents.pilot_stamp == bindparam("b_pilot_stamp")) - .values( - { - key: bindparam(key) - for key in pilot_stamps_to_fields_mapping[0] - .model_dump(exclude_none=True) - .keys() - if key != "PilotStamp" - } - ) + .values(**case_expressions) + .where(PilotAgents.__table__.c.PilotStamp.in_(updates_by_stamp.keys())) ) + result = await self.conn.execute(stmt) - values = [ - { - **{"b_pilot_stamp": mapping.PilotStamp}, - **mapping.model_dump(exclude={"PilotStamp"}, exclude_none=True), - } - for mapping in pilot_stamps_to_fields_mapping - ] - - res = await self.conn.execute(stmt, values) - - if res.rowcount != len(pilot_stamps_to_fields_mapping): - raise PilotNotFoundError("at least one of the given pilot does not exist.") - - # ----------------------------- Search Functions ----------------------------- + if result.rowcount != len(updates_by_stamp): + raise PilotNotFoundError("at least one of the given pilots does not exist.") async def search_pilots( self, @@ -198,7 +172,7 @@ async def search_pilots( distinct: bool = False, per_page: int = 100, page: int | None = None, - ) -> tuple[int, list[dict[Any, Any]]]: + ) -> tuple[int, list[dict[str, Any]]]: """Search for pilot information in the database.""" return await self._search( table=PilotAgents, @@ -210,29 +184,41 @@ async def search_pilots( page=page, ) - async def search_pilot_to_job_mapping( - self, - parameters: list[str] | None, - search: list[SearchSpec], - sorts: list[SortSpec], - *, - distinct: bool = False, - per_page: int = 100, - page: int | None = None, - ) -> tuple[int, list[dict[Any, Any]]]: - """Search for jobs that are associated with pilots.""" - return await self._search( - table=JobToPilotMapping, - parameters=parameters, - search=search, - sorts=sorts, - distinct=distinct, - per_page=per_page, - page=page, - ) - async def pilot_summary( self, group_by: list[str], search: list[SearchSpec] ) -> list[dict[str, str | int]]: - """Get a summary of the pilots.""" + """Aggregate pilot counts by the requested columns.""" return await self._summary(table=PilotAgents, group_by=group_by, search=search) + + async def job_ids_for_stamps(self, pilot_stamps: list[str]) -> list[int]: + """Return the IDs of jobs that have run on any of the given pilot stamps. + + Single round-trip SQL join over JobToPilotMapping and PilotAgents + (both live in the same metadata, so the join is legitimate at the + DB layer). + """ + if not pilot_stamps: + return [] + stmt = ( + select(JobToPilotMapping.job_id) + .join( + PilotAgents, + PilotAgents.pilot_id == JobToPilotMapping.pilot_id, + ) + .where(PilotAgents.pilot_stamp.in_(pilot_stamps)) + .distinct() + ) + result = await self.conn.execute(stmt) + return [row[0] for row in result] + + async def pilot_ids_for_job_ids(self, job_ids: list[int]) -> list[int]: + """Return the IDs of pilots that have run any of the given jobs.""" + if not job_ids: + return [] + stmt = ( + select(JobToPilotMapping.pilot_id) + .where(JobToPilotMapping.job_id.in_(job_ids)) + .distinct() + ) + result = await self.conn.execute(stmt) + return [row[0] for row in result] diff --git a/diracx-db/src/diracx/db/sql/pilots/schema.py b/diracx-db/src/diracx/db/sql/pilots/schema.py index 0cd95ff9a..ba6c65a86 100644 --- a/diracx-db/src/diracx/db/sql/pilots/schema.py +++ b/diracx-db/src/diracx/db/sql/pilots/schema.py @@ -1,7 +1,6 @@ from __future__ import annotations from datetime import datetime -from typing import Optional from sqlalchemy import ( Double, @@ -49,10 +48,10 @@ class PilotAgents(PilotAgentsDBBase): vo: Mapped[str128] = mapped_column("VO") grid_type: Mapped[str32] = mapped_column("GridType", default="LCG") benchmark: Mapped[float] = mapped_column("BenchMark", Double, default=0.0) - submission_time: Mapped[Optional[datetime]] = mapped_column( + submission_time: Mapped[datetime | None] = mapped_column( "SubmissionTime", SmarterDateTime ) - last_update_time: Mapped[Optional[datetime]] = mapped_column( + last_update_time: Mapped[datetime | None] = mapped_column( "LastUpdateTime", SmarterDateTime ) status: Mapped[str32] = mapped_column("Status", default=PilotStatus.UNKNOWN) diff --git a/diracx-db/tests/pilots/test_pilot_management.py b/diracx-db/tests/pilots/test_pilot_management.py index 2adabb0d2..e9b4e4a32 100644 --- a/diracx-db/tests/pilots/test_pilot_management.py +++ b/diracx-db/tests/pilots/test_pilot_management.py @@ -3,27 +3,22 @@ from datetime import datetime, timezone import pytest +import pytest_asyncio from diracx.core.exceptions import ( PilotAlreadyAssociatedWithJobError, + PilotNotFoundError, ) -from diracx.core.models.pilot import PilotFieldsMapping, PilotStatus +from diracx.core.models.pilot import PilotMetadata, PilotStatus from diracx.db.sql.pilots.db import PilotAgentsDB -from .utils import ( - add_stamps, # noqa: F401 - create_old_pilots_environment, # noqa: F401 - create_timed_pilots, # noqa: F401 - get_pilot_jobs_ids_by_pilot_id, - get_pilots_by_stamp, -) +from .utils import get_pilots_by_stamp MAIN_VO = "lhcb" -N = 100 -@pytest.fixture -async def pilot_db(tmp_path): +@pytest_asyncio.fixture +async def pilot_db(): agents_db = PilotAgentsDB("sqlite+aiosqlite:///:memory:") async with agents_db.engine_context(): async with agents_db.engine.begin() as conn: @@ -32,162 +27,102 @@ async def pilot_db(tmp_path): @pytest.mark.asyncio -async def test_insert_and_select(pilot_db: PilotAgentsDB): - async with pilot_db as pilot_db: - # Add pilots - refs = [f"ref_{i}" for i in range(10)] - stamps = [f"stamp_{i}" for i in range(10)] - pilot_references = dict(zip(stamps, refs)) - - await pilot_db.add_pilots( - stamps, MAIN_VO, grid_type="DIRAC", pilot_references=pilot_references - ) - - # Accept duplicates because it is checked by the logic - await pilot_db.add_pilots( - stamps, MAIN_VO, grid_type="DIRAC", pilot_references=None - ) - - -@pytest.mark.asyncio -async def test_insert_and_delete(pilot_db: PilotAgentsDB): - async with pilot_db as pilot_db: - # Add pilots - refs = [f"ref_{i}" for i in range(2)] - stamps = [f"stamp_{i}" for i in range(2)] - pilot_references = dict(zip(stamps, refs)) - - await pilot_db.add_pilots( - stamps, MAIN_VO, grid_type="DIRAC", pilot_references=pilot_references - ) - - # Works, the pilots exists - res = await get_pilots_by_stamp(pilot_db, [stamps[0]]) - await get_pilots_by_stamp(pilot_db, [stamps[0]]) - - # We delete the first pilot - await pilot_db.delete_pilots([res[0]["PilotID"]]) - - # We get the 2nd pilot that is not delete (no error) - await get_pilots_by_stamp(pilot_db, [stamps[1]]) - # We get the 1st pilot that is delete (error) - - assert not await get_pilots_by_stamp(pilot_db, [stamps[0]]) - - -@pytest.mark.asyncio -async def test_insert_and_select_single_then_modify(pilot_db: PilotAgentsDB): - async with pilot_db as pilot_db: - pilot_stamp = "stamp-test" - await pilot_db.add_pilots( +async def test_register_pilots_roundtrip(pilot_db: PilotAgentsDB): + """Register a pilot and read it back with its defaults.""" + async with pilot_db as db: + await db.register_pilots( + pilot_stamps=["stamp-a"], vo=MAIN_VO, - pilot_stamps=[pilot_stamp], grid_type="grid-type", ) - res = await get_pilots_by_stamp(pilot_db, [pilot_stamp]) - assert len(res) == 1 - pilot = res[0] - - # Assert values + rows = await get_pilots_by_stamp(db, ["stamp-a"]) + assert len(rows) == 1 + pilot = rows[0] assert pilot["VO"] == MAIN_VO - assert pilot["PilotStamp"] == pilot_stamp + assert pilot["PilotStamp"] == "stamp-a" assert pilot["GridType"] == "grid-type" assert pilot["BenchMark"] == 0.0 assert pilot["Status"] == PilotStatus.SUBMITTED - assert pilot["StatusReason"] == "Unknown" assert not pilot["AccountingSent"] - # - # Modify a pilot, then check if every change is done - # - await pilot_db.update_pilot_fields( + +@pytest.mark.asyncio +async def test_update_pilot_metadata_heterogeneous(pilot_db: PilotAgentsDB): + """A single call must support rows that set different field subsets. + + Row 1 sets only Status. Row 2 sets only BenchMark. The CASE-based bulk + update must apply each subset correctly without raising. + """ + async with pilot_db as db: + await db.register_pilots( + pilot_stamps=["stamp-1", "stamp-2"], + vo=MAIN_VO, + grid_type="DIRAC", + ) + + await db.update_pilot_metadata( [ - PilotFieldsMapping( - PilotStamp=pilot_stamp, - BenchMark=1.0, - StatusReason="NewReason", - AccountingSent=True, - Status=PilotStatus.WAITING, - ) + PilotMetadata(PilotStamp="stamp-1", Status=PilotStatus.WAITING), + PilotMetadata(PilotStamp="stamp-2", BenchMark=42.0), ] ) - res = await get_pilots_by_stamp(pilot_db, [pilot_stamp]) - assert len(res) == 1 - pilot = res[0] + rows = { + r["PilotStamp"]: r + for r in await get_pilots_by_stamp(db, ["stamp-1", "stamp-2"]) + } + assert rows["stamp-1"]["Status"] == PilotStatus.WAITING + assert rows["stamp-1"]["BenchMark"] == 0.0 # untouched + assert rows["stamp-2"]["Status"] == PilotStatus.SUBMITTED # untouched + assert rows["stamp-2"]["BenchMark"] == 42.0 - # Set values - assert pilot["VO"] == MAIN_VO - assert pilot["PilotStamp"] == pilot_stamp - assert pilot["GridType"] == "grid-type" - assert pilot["BenchMark"] == 1.0 - assert pilot["Status"] == PilotStatus.WAITING - assert pilot["StatusReason"] == "NewReason" - assert pilot["AccountingSent"] + +@pytest.mark.asyncio +async def test_update_pilot_metadata_missing_pilot_raises(pilot_db: PilotAgentsDB): + async with pilot_db as db: + with pytest.raises(PilotNotFoundError): + await db.update_pilot_metadata( + [PilotMetadata(PilotStamp="nope", Status=PilotStatus.DONE)] + ) @pytest.mark.asyncio -async def test_associate_pilot_with_job_and_get_it(pilot_db: PilotAgentsDB): - """We will proceed in few steps. - - 1. Create a pilot - 2. Verify that he is not associated with any job - 3. Associate with jobs - 4. Verify that he is associate with this job - 5. Associate with jobs that he already has and two that he has not - 6. Associate with jobs that he has not, but were involved in a crash - """ - async with pilot_db as pilot_db: - pilot_stamp = "stamp-test" - # Add pilot - await pilot_db.add_pilots( - vo=MAIN_VO, - pilot_stamps=[pilot_stamp], - grid_type="grid-type", +async def test_assign_jobs_to_pilot_duplicate_raises(pilot_db: PilotAgentsDB): + """Second assignment of the same (pilot, job) pair must raise.""" + async with pilot_db as db: + await db.register_pilots(pilot_stamps=["stamp-x"], vo=MAIN_VO) + rows = await get_pilots_by_stamp(db, ["stamp-x"]) + pilot_id = rows[0]["PilotID"] + now = datetime.now(tz=timezone.utc) + + await db.assign_jobs_to_pilot( + [{"PilotID": pilot_id, "JobID": 1, "StartTime": now}] ) - res = await get_pilots_by_stamp(pilot_db, [pilot_stamp]) - assert len(res) == 1 - pilot = res[0] - pilot_id = pilot["PilotID"] + with pytest.raises(PilotAlreadyAssociatedWithJobError): + await db.assign_jobs_to_pilot( + [ + {"PilotID": pilot_id, "JobID": 1, "StartTime": now}, + {"PilotID": pilot_id, "JobID": 2, "StartTime": now}, + ] + ) - # Verify that he has no jobs - assert len(await get_pilot_jobs_ids_by_pilot_id(pilot_db, pilot_id)) == 0 +@pytest.mark.asyncio +async def test_delete_cascades_logs_and_mappings(pilot_db: PilotAgentsDB): + """Deleting a pilot removes its logs and job associations.""" + async with pilot_db as db: + await db.register_pilots(pilot_stamps=["stamp-z"], vo=MAIN_VO) + rows = await get_pilots_by_stamp(db, ["stamp-z"]) + pilot_id = rows[0]["PilotID"] now = datetime.now(tz=timezone.utc) + await db.assign_jobs_to_pilot( + [{"PilotID": pilot_id, "JobID": 100, "StartTime": now}] + ) - # Associate pilot with jobs - pilot_jobs = [1, 2, 3] - # Prepare the list of dictionaries for bulk insertion - job_to_pilot_mapping = [ - {"PilotID": pilot_id, "JobID": job_id, "StartTime": now} - for job_id in pilot_jobs - ] - await pilot_db.add_jobs_to_pilot(job_to_pilot_mapping) - - # Verify that he has all jobs - db_jobs = await get_pilot_jobs_ids_by_pilot_id(pilot_db, pilot_id) - # We test both length and if every job is included if for any reason we have duplicates - assert all(job in db_jobs for job in pilot_jobs) - assert len(pilot_jobs) == len(db_jobs) - - # Associate pilot with a job that he already has, and one that he has not - pilot_jobs = [10, 1, 5] - with pytest.raises(PilotAlreadyAssociatedWithJobError): - # Prepare the list of dictionaries for bulk insertion - job_to_pilot_mapping = [ - {"PilotID": pilot_id, "JobID": job_id, "StartTime": now} - for job_id in pilot_jobs - ] - await pilot_db.add_jobs_to_pilot(job_to_pilot_mapping) - - # Associate pilot with jobs that he has not, but was previously in an error - # To test that the rollback worked - pilot_jobs = [5, 10] - # Prepare the list of dictionaries for bulk insertion - job_to_pilot_mapping = [ - {"PilotID": pilot_id, "JobID": job_id, "StartTime": now} - for job_id in pilot_jobs - ] - await pilot_db.add_jobs_to_pilot(job_to_pilot_mapping) + await db.remove_jobs_from_pilots([pilot_id]) + await db.delete_pilot_logs([pilot_id]) + await db.delete_pilots([pilot_id]) + + assert await get_pilots_by_stamp(db, ["stamp-z"]) == [] diff --git a/diracx-db/tests/pilots/test_query.py b/diracx-db/tests/pilots/test_query.py index 592329d8c..ddf0b9392 100644 --- a/diracx-db/tests/pilots/test_query.py +++ b/diracx-db/tests/pilots/test_query.py @@ -1,9 +1,9 @@ from __future__ import annotations import pytest +import pytest_asyncio -from diracx.core.exceptions import InvalidQueryError -from diracx.core.models.pilot import PilotFieldsMapping, PilotStatus +from diracx.core.models.pilot import PilotMetadata, PilotStatus from diracx.core.models.search import ( ScalarSearchOperator, ScalarSearchSpec, @@ -15,11 +15,11 @@ from diracx.db.sql.pilots.db import PilotAgentsDB MAIN_VO = "lhcb" -N = 100 +N = 20 -@pytest.fixture -async def pilot_db(tmp_path): +@pytest_asyncio.fixture +async def pilot_db(): agents_db = PilotAgentsDB("sqlite+aiosqlite:///:memory:") async with agents_db.engine_context(): async with agents_db.engine.begin() as conn: @@ -27,273 +27,86 @@ async def pilot_db(tmp_path): yield agents_db -PILOT_REASONS = [ - "I was sick", - "I can't, I have a pony.", - "I was shopping", - "I was sleeping", -] - -PILOT_STATUSES = list(PilotStatus) - - -@pytest.fixture +@pytest_asyncio.fixture async def populated_pilot_db(pilot_db): - async with pilot_db as pilot_db: - # Add pilots - refs = [f"ref_{i + 1}" for i in range(N)] - stamps = [f"stamp_{i + 1}" for i in range(N)] - pilot_references = dict(zip(stamps, refs)) - - vo = MAIN_VO - - await pilot_db.add_pilots( - stamps, vo, grid_type="DIRAC", pilot_references=pilot_references + """Small, focused dataset: N pilots with alternating statuses.""" + stamps = [f"stamp_{i + 1}" for i in range(N)] + async with pilot_db as db: + await db.register_pilots( + pilot_stamps=stamps, + vo=MAIN_VO, + grid_type="DIRAC", + pilot_references={s: f"ref_{i + 1}" for i, s in enumerate(stamps)}, ) - - await pilot_db.update_pilot_fields( + await db.update_pilot_metadata( [ - PilotFieldsMapping( - PilotStamp=pilot_stamp, - BenchMark=i**2, - StatusReason=PILOT_REASONS[i % len(PILOT_REASONS)], - AccountingSent=True, - Status=PILOT_STATUSES[i % len(PILOT_STATUSES)], - CurrentJobID=i, - Queue=f"queue_{i}", + PilotMetadata( + PilotStamp=stamp, + BenchMark=float(i), + Status=(PilotStatus.ABORTED if i % 2 == 0 else PilotStatus.WAITING), ) - for i, pilot_stamp in enumerate(stamps) + for i, stamp in enumerate(stamps) ] ) - yield pilot_db -async def test_search_parameters(populated_pilot_db): - """Test that we can search specific parameters for pilots in the database.""" - async with populated_pilot_db as pilot_db: - # Search a specific parameter: PilotID - total, result = await pilot_db.search_pilots(["PilotID"], [], []) - assert total == N - assert result - for r in result: - assert r.keys() == {"PilotID"} - - # Search a specific parameter: Status - total, result = await pilot_db.search_pilots(["Status"], [], []) - assert total == N - assert result - for r in result: - assert r.keys() == {"Status"} +@pytest.mark.asyncio +async def test_search_smoke(populated_pilot_db): + """Smoke-test the generic search helpers through the pilots table. - # Search for multiple parameters: PilotID, Status - total, result = await pilot_db.search_pilots(["PilotID", "Status"], [], []) + Deep search semantics (operators, pagination, sorting) are covered by + the shared base class tests; here we only verify the pilot-specific + search functions plug into the base correctly. + """ + async with populated_pilot_db as db: + # Select all, count and verify shape + total, result = await db.search_pilots(None, [], []) assert total == N - assert result - for r in result: - assert r.keys() == {"PilotID", "Status"} - - # Search for a specific parameter but use distinct: Status - total, result = await pilot_db.search_pilots(["Status"], [], [], distinct=True) - assert total == len(PILOT_STATUSES) - assert result - - # Search for a non-existent parameter: Dummy - with pytest.raises(InvalidQueryError): - total, result = await pilot_db.search_pilots(["Dummy"], [], []) + assert all("PilotID" in row for row in result) - -async def test_search_conditions(populated_pilot_db): - """Test that we can search for specific pilots in the database.""" - async with populated_pilot_db as pilot_db: - # Search a specific scalar condition: PilotID eq 3 - condition = ScalarSearchSpec( - parameter="PilotID", operator=ScalarSearchOperator.EQUAL, value=3 - ) - total, result = await pilot_db.search_pilots([], [condition], []) - assert total == 1 - assert result - assert len(result) == 1 - assert result[0]["PilotID"] == 3 - - # Search a specific scalar condition: PilotID lt 3 - condition = ScalarSearchSpec( - parameter="PilotID", operator=ScalarSearchOperator.LESS_THAN, value=3 - ) - total, result = await pilot_db.search_pilots([], [condition], []) - assert total == 2 - assert result - assert len(result) == 2 - assert result[0]["PilotID"] == 1 - assert result[1]["PilotID"] == 2 - - # Search a specific scalar condition: PilotID neq 3 - condition = ScalarSearchSpec( - parameter="PilotID", operator=ScalarSearchOperator.NOT_EQUAL, value=3 - ) - total, result = await pilot_db.search_pilots([], [condition], []) - assert total == 99 - assert result - assert len(result) == 99 - assert all(r["PilotID"] != 3 for r in result) - - # Search a specific scalar condition: PilotID eq 5873 (does not exist) - condition = ScalarSearchSpec( - parameter="PilotID", operator=ScalarSearchOperator.EQUAL, value=5873 - ) - total, result = await pilot_db.search_pilots([], [condition], []) - assert not result - - # Search a specific vector condition: PilotID in 1,2,3 - condition = VectorSearchSpec( - parameter="PilotID", operator=VectorSearchOperator.IN, values=[1, 2, 3] + # Filter by vector IN + total, result = await db.search_pilots( + ["PilotStamp"], + [ + VectorSearchSpec( + parameter="PilotStamp", + operator=VectorSearchOperator.IN, + values=["stamp_1", "stamp_2", "stamp_3"], + ) + ], + [], ) - total, result = await pilot_db.search_pilots([], [condition], []) assert total == 3 - assert result - assert len(result) == 3 - assert all(r["PilotID"] in [1, 2, 3] for r in result) - # Search a specific vector condition: PilotID in 1,2,5873 (one of them does not exist) - condition = VectorSearchSpec( - parameter="PilotID", operator=VectorSearchOperator.IN, values=[1, 2, 5873] - ) - total, result = await pilot_db.search_pilots([], [condition], []) - assert total == 2 - assert result - assert len(result) == 2 - assert all(r["PilotID"] in [1, 2] for r in result) - - # Search a specific vector condition: PilotID not in 1,2,3 - condition = VectorSearchSpec( - parameter="PilotID", operator=VectorSearchOperator.NOT_IN, values=[1, 2, 3] - ) - total, result = await pilot_db.search_pilots([], [condition], []) - assert total == 97 - assert result - assert len(result) == 97 - assert all(r["PilotID"] not in [1, 2, 3] for r in result) - - # Search a specific vector condition: PilotID not in 1,2,5873 (one of them does not exist) - condition = VectorSearchSpec( - parameter="PilotID", - operator=VectorSearchOperator.NOT_IN, - values=[1, 2, 5873], - ) - total, result = await pilot_db.search_pilots([], [condition], []) - assert total == 98 - assert result - assert len(result) == 98 - assert all(r["PilotID"] not in [1, 2] for r in result) - - # Search for multiple conditions based on different parameters: PilotID eq 70, PilotID in 4,5,6 - condition1 = ScalarSearchSpec( - parameter="PilotStamp", operator=ScalarSearchOperator.EQUAL, value="stamp_5" - ) - condition2 = VectorSearchSpec( - parameter="PilotID", operator=VectorSearchOperator.IN, values=[4, 5, 6] + # Filter by scalar EQ on a mutable column + total, result = await db.search_pilots( + ["PilotStamp"], + [ + ScalarSearchSpec( + parameter="Status", + operator=ScalarSearchOperator.EQUAL, + value=PilotStatus.ABORTED, + ) + ], + [], ) - total, result = await pilot_db.search_pilots([], [condition1, condition2], []) - assert total == 1 - assert result - assert len(result) == 1 - assert result[0]["PilotID"] == 5 - assert result[0]["PilotStamp"] == "stamp_5" + assert total == N // 2 - # Search for multiple conditions based on the same parameter: PilotID eq 70, PilotID in 4,5,6 - condition1 = ScalarSearchSpec( - parameter="PilotID", operator=ScalarSearchOperator.EQUAL, value=70 + # Sort descending + _, result = await db.search_pilots( + ["PilotID"], + [], + [SortSpec(parameter="PilotID", direction=SortDirection.DESC)], ) - condition2 = VectorSearchSpec( - parameter="PilotID", operator=VectorSearchOperator.IN, values=[4, 5, 6] - ) - total, result = await pilot_db.search_pilots([], [condition1, condition2], []) - assert total == 0 - assert not result - + assert [r["PilotID"] for r in result][0] == N -async def test_search_sorts(populated_pilot_db): - """Test that we can search for pilots in the database and sort the results.""" - async with populated_pilot_db as pilot_db: - # Search and sort by PilotID in ascending order - sort = SortSpec(parameter="PilotID", direction=SortDirection.ASC) - total, result = await pilot_db.search_pilots([], [], [sort]) - assert total == N - assert result - for i, r in enumerate(result): - assert r["PilotID"] == i + 1 - - # Search and sort by PilotID in descending order - sort = SortSpec(parameter="PilotID", direction=SortDirection.DESC) - total, result = await pilot_db.search_pilots([], [], [sort]) - assert total == N - assert result - for i, r in enumerate(result): - assert r["PilotID"] == N - i - # Search and sort by PilotStamp in ascending order - sort = SortSpec(parameter="PilotStamp", direction=SortDirection.ASC) - total, result = await pilot_db.search_pilots([], [], [sort]) - assert total == N - assert result - # Assert that stamp_10 is before stamp_2 because of the lexicographical order - assert result[2]["PilotStamp"] == "stamp_100" - assert result[12]["PilotStamp"] == "stamp_2" - - # Search and sort by PilotStamp in descending order - sort = SortSpec(parameter="PilotStamp", direction=SortDirection.DESC) - total, result = await pilot_db.search_pilots([], [], [sort]) - assert total == N - assert result - # Assert that stamp_10 is before stamp_2 because of the lexicographical order - assert result[97]["PilotStamp"] == "stamp_100" - assert result[87]["PilotStamp"] == "stamp_2" - - # Search and sort by PilotStamp in ascending order and PilotID in descending order - sort1 = SortSpec(parameter="PilotStamp", direction=SortDirection.ASC) - sort2 = SortSpec(parameter="PilotID", direction=SortDirection.DESC) - total, result = await pilot_db.search_pilots([], [], [sort1, sort2]) - assert total == N - assert result - assert result[0]["PilotStamp"] == "stamp_1" - assert result[0]["PilotID"] == 1 - assert result[99]["PilotStamp"] == "stamp_99" - assert result[99]["PilotID"] == 99 - - -@pytest.mark.parametrize( - "per_page, page, expected_len, expected_first_id, expect_exception", - [ - (10, 1, 10, 1, None), # Page 1 - (10, 2, 10, 11, None), # Page 2 - (10, 10, 10, 91, None), # Page 10 - (50, 2, 50, 51, None), # Page 2 with 50 per page - (10, 11, 0, None, None), # Page beyond range, should return empty - (10, 0, None, None, InvalidQueryError), # Invalid page - (0, 1, None, None, InvalidQueryError), # Invalid per_page - ], -) -async def test_search_pagination( - populated_pilot_db, - per_page, - page, - expected_len, - expected_first_id, - expect_exception, -): - """Test pagination logic in pilot search.""" - async with populated_pilot_db as pilot_db: - if expect_exception: - with pytest.raises(expect_exception): - await pilot_db.search_pilots([], [], [], per_page=per_page, page=page) - else: - total, result = await pilot_db.search_pilots( - [], [], [], per_page=per_page, page=page - ) - assert total == N - if expected_len == 0: - assert not result - else: - assert result - assert len(result) == expected_len - assert result[0]["PilotID"] == expected_first_id +@pytest.mark.asyncio +async def test_pilot_summary_groups_by_status(populated_pilot_db): + """`pilot_summary` must aggregate by the requested column.""" + async with populated_pilot_db as db: + rows = await db.pilot_summary(group_by=["Status"], search=[]) + by_status = {r["Status"]: r["count"] for r in rows} + assert by_status[PilotStatus.ABORTED] == N // 2 + assert by_status[PilotStatus.WAITING] == N // 2 diff --git a/diracx-db/tests/pilots/utils.py b/diracx-db/tests/pilots/utils.py index d803a3597..b1cff2a10 100644 --- a/diracx-db/tests/pilots/utils.py +++ b/diracx-db/tests/pilots/utils.py @@ -1,29 +1,20 @@ from __future__ import annotations -from datetime import datetime, timezone from typing import Any -import pytest -from sqlalchemy import update - from diracx.core.models.search import ( - ScalarSearchOperator, - ScalarSearchSpec, VectorSearchOperator, VectorSearchSpec, ) from diracx.db.sql.pilots.db import PilotAgentsDB -from diracx.db.sql.pilots.schema import PilotAgents - -MAIN_VO = "lhcb" -N = 100 - -# ------------ Fetching data ------------ async def get_pilots_by_stamp( - pilot_db: PilotAgentsDB, pilot_stamps: list[str], parameters: list[str] = [] + pilot_db: PilotAgentsDB, + pilot_stamps: list[str], + parameters: list[str] | None = None, ) -> list[dict[Any, Any]]: + """Test helper: fetch pilots by stamp directly from the DB layer.""" _, pilots = await pilot_db.search_pilots( parameters=parameters, search=[ @@ -37,115 +28,4 @@ async def get_pilots_by_stamp( distinct=True, per_page=1000, ) - return pilots - - -async def get_pilot_jobs_ids_by_pilot_id( - pilot_db: PilotAgentsDB, pilot_id: int -) -> list[int]: - _, jobs = await pilot_db.search_pilot_to_job_mapping( - parameters=["JobID"], - search=[ - ScalarSearchSpec( - parameter="PilotID", - operator=ScalarSearchOperator.EQUAL, - value=pilot_id, - ) - ], - sorts=[], - distinct=True, - per_page=10000, - ) - - return [job["JobID"] for job in jobs] - - -# ------------ Creating data ------------ - - -@pytest.fixture -async def add_stamps(pilot_db): - async def _add_stamps(start_n=0): - async with pilot_db as db: - # Add pilots - refs = [f"ref_{i}" for i in range(start_n, start_n + N)] - stamps = [f"stamp_{i}" for i in range(start_n, start_n + N)] - pilot_references = dict(zip(stamps, refs)) - - vo = MAIN_VO - - await db.add_pilots( - stamps, vo, grid_type="DIRAC", pilot_references=pilot_references - ) - - return await get_pilots_by_stamp(db, stamps) - - return _add_stamps - - -@pytest.fixture -async def create_timed_pilots(pilot_db, add_stamps): - async def _create_timed_pilots( - old_date: datetime, aborted: bool = False, start_n=0 - ): - # Get pilots - pilots = await add_stamps(start_n) - - async with pilot_db as db: - # Update manually their age - # Collect PilotStamps - pilot_stamps = [pilot["PilotStamp"] for pilot in pilots] - - stmt = ( - update(PilotAgents) - .where(PilotAgents.pilot_stamp.in_(pilot_stamps)) - .values(SubmissionTime=old_date) - ) - - if aborted: - stmt = stmt.values(Status="Aborted") - - res = await db.conn.execute(stmt) - assert res.rowcount == len(pilot_stamps) - - pilots = await get_pilots_by_stamp(db, pilot_stamps) - return pilots - - return _create_timed_pilots - - -@pytest.fixture -async def create_old_pilots_environment(pilot_db, create_timed_pilots): - non_aborted_recent = await create_timed_pilots( - datetime(2025, 1, 1, tzinfo=timezone.utc), False, N - ) - aborted_recent = await create_timed_pilots( - datetime(2025, 1, 1, tzinfo=timezone.utc), True, 2 * N - ) - - aborted_very_old = await create_timed_pilots( - datetime(2003, 3, 10, tzinfo=timezone.utc), True, 3 * N - ) - non_aborted_very_old = await create_timed_pilots( - datetime(2003, 3, 10, tzinfo=timezone.utc), False, 4 * N - ) - - pilot_number = 4 * N - - assert pilot_number == ( - len(non_aborted_recent) - + len(aborted_recent) - + len(aborted_very_old) - + len(non_aborted_very_old) - ) - - # Phase 0. Verify that we have the right environment - async with pilot_db as pilot_db: - # Ensure that we can get every pilot (only get first of each group) - await get_pilots_by_stamp(pilot_db, [non_aborted_recent[0]["PilotStamp"]]) - await get_pilots_by_stamp(pilot_db, [aborted_recent[0]["PilotStamp"]]) - await get_pilots_by_stamp(pilot_db, [aborted_very_old[0]["PilotStamp"]]) - await get_pilots_by_stamp(pilot_db, [non_aborted_very_old[0]["PilotStamp"]]) - - return non_aborted_recent, aborted_recent, non_aborted_very_old, aborted_very_old diff --git a/diracx-logic/src/diracx/logic/jobs/query.py b/diracx-logic/src/diracx/logic/jobs/query.py index a3e1a9732..0968c34be 100644 --- a/diracx-logic/src/diracx/logic/jobs/query.py +++ b/diracx-logic/src/diracx/logic/jobs/query.py @@ -4,33 +4,114 @@ from typing import Any from diracx.core.config.schema import Config +from diracx.core.exceptions import InvalidQueryError from diracx.core.models.search import ( ScalarSearchOperator, SearchParams, SummaryParams, + VectorSearchOperator, + VectorSearchSpec, ) from diracx.db.os.job_parameters import JobParametersDB from diracx.db.sql.job.db import JobDB from diracx.db.sql.job_logging.db import JobLoggingDB +from diracx.db.sql.pilots.db import PilotAgentsDB +from diracx.logic.pilots.query import resolve_jobs_for_pilot_stamps logger = logging.getLogger(__name__) MAX_PER_PAGE = 10000 +# Pseudo-parameter accepted on POST /api/jobs/search. Resolves to a +# JobID IN (...) filter via JobToPilotMapping. +PILOT_STAMP_PSEUDO_PARAM = "PilotStamp" +# Real Jobs column that PilotStamp would collide with if both were +# accepted in the same request body. +JOB_ID_REAL_PARAM = "JobID" + + +async def _rewrite_pilot_stamp_pseudo_param( + pilot_db: PilotAgentsDB, body: SearchParams +) -> bool: + """Rewrite any `PilotStamp` pseudo-parameter in `body.search`. + + Collects every `PilotStamp` filter, resolves them through + `JobToPilotMapping`, removes the originals from `body.search`, and + appends a single `JobID IN (...)` vector filter. Returns `True` + if the resolution produced an empty list (the caller should + short-circuit to an empty result), `False` otherwise. + + Supports `eq` and `in` operators only; every other operator raises + `InvalidQueryError` because the join semantics are ambiguous. + Combining a `PilotStamp` pseudo-filter with a real `JobID` filter + in the same body is also refused. + """ + matches = [ + spec + for spec in body.search + if spec.get("parameter") == PILOT_STAMP_PSEUDO_PARAM + ] + if not matches: + return False + + if any(spec.get("parameter") == JOB_ID_REAL_PARAM for spec in body.search): + raise InvalidQueryError( + f"Cannot combine {PILOT_STAMP_PSEUDO_PARAM!r} pseudo-parameter " + f"with a real {JOB_ID_REAL_PARAM!r} filter in the same request." + ) + + stamps: list[str] = [] + for spec in matches: + operator = spec.get("operator") + if operator == ScalarSearchOperator.EQUAL: + stamps.append(str(spec["value"])) # type: ignore[typeddict-item] + elif operator == VectorSearchOperator.IN: + stamps.extend(str(v) for v in spec["values"]) # type: ignore[typeddict-item] + else: + raise InvalidQueryError( + f"Operator {operator!r} is not supported on the " + f"{PILOT_STAMP_PSEUDO_PARAM!r} pseudo-parameter; " + "use 'eq' or 'in'." + ) + + job_ids = await resolve_jobs_for_pilot_stamps(pilot_db, stamps) + body.search = [ + spec + for spec in body.search + if spec.get("parameter") != PILOT_STAMP_PSEUDO_PARAM + ] + if not job_ids: + return True + body.search.append( + VectorSearchSpec( + parameter=JOB_ID_REAL_PARAM, + operator=VectorSearchOperator.IN, + values=job_ids, + ) + ) + return False + async def search( config: Config, job_db: JobDB, job_parameters_db: JobParametersDB, job_logging_db: JobLoggingDB, + pilot_db: PilotAgentsDB, preferred_username: str | None, vo: str, page: int = 1, per_page: int = 100, body: SearchParams | None = None, ) -> tuple[int, list[dict[str, Any]]]: - """Retrieve information about jobs.""" + """Retrieve information about jobs. + + Accepts a `PilotStamp` pseudo-parameter in `body.search` + (`eq`/`in` only): it is resolved through `JobToPilotMapping` into + a concrete `JobID` vector filter before the main query runs. Mirrors + the `JobID` pseudo-parameter on `POST /api/pilots/search`. + """ # Apply a limit to per_page to prevent abuse of the API if per_page > MAX_PER_PAGE: per_page = MAX_PER_PAGE @@ -38,6 +119,10 @@ async def search( if body is None: body = SearchParams() + empty_after_rewrite = await _rewrite_pilot_stamp_pseudo_param(pilot_db, body) + if empty_after_rewrite: + return 0, [] + if query_logging_info := ("LoggingInfo" in (body.parameters or [])): if body.parameters: body.parameters.remove("LoggingInfo") diff --git a/diracx-logic/src/diracx/logic/pilots/management.py b/diracx-logic/src/diracx/logic/pilots/management.py index 5698b447e..6d557debf 100644 --- a/diracx-logic/src/diracx/logic/pilots/management.py +++ b/diracx-logic/src/diracx/logic/pilots/management.py @@ -1,17 +1,18 @@ from __future__ import annotations from datetime import datetime, timedelta, timezone +from typing import Any from diracx.core.exceptions import PilotAlreadyExistsError, PilotNotFoundError -from diracx.core.models.pilot import PilotFieldsMapping +from diracx.core.models.pilot import PilotMetadata, PilotStatus +from diracx.core.models.search import ( + ScalarSearchOperator, + ScalarSearchSpec, + SearchSpec, +) from diracx.db.sql import PilotAgentsDB -from .query import ( - get_outdated_pilots, - get_pilot_ids_by_stamps, - get_pilot_jobs_ids_by_pilot_id, - get_pilots_by_stamp, -) +from .query import get_pilots_by_stamp async def register_new_pilots( @@ -24,21 +25,27 @@ async def register_new_pilots( status: str, pilot_job_references: dict[str, str] | None, ): - # [IMPORTANT] Check unicity of pilot stamps - # If a pilot already exists, we raise an error (transaction will rollback) + """Register a batch of new pilots. + + Raises `PilotAlreadyExistsError` if any stamp already exists. + + Uniqueness is best-effort: the DIRAC `PilotAgents` schema has no unique + constraint on `PilotStamp` (only a non-unique key), so a concurrent + registration of the same stamp from two processes could race past this + check. In practice pilot stamps are cryptographically random UUIDs, + making the collision window negligible. + """ existing_pilots = await get_pilots_by_stamp( pilot_db=pilot_db, pilot_stamps=pilot_stamps ) - # If we found pilots from the list, this means some pilots already exist - if len(existing_pilots) > 0: + if existing_pilots: found_keys = {pilot["PilotStamp"] for pilot in existing_pilots} - raise PilotAlreadyExistsError( f"The following pilots already exist: {found_keys}" ) - await pilot_db.add_pilots( + await pilot_db.register_pilots( pilot_stamps=pilot_stamps, vo=vo, grid_type=grid_type, @@ -51,72 +58,119 @@ async def register_new_pilots( async def delete_pilots( pilot_db: PilotAgentsDB, + *, pilot_stamps: list[str] | None = None, age_in_days: int | None = None, delete_only_aborted: bool = True, vo_constraint: str | None = None, ): - if pilot_stamps: - pilot_ids = await get_pilot_ids_by_stamps( - pilot_db=pilot_db, pilot_stamps=pilot_stamps, allow_missing=True - ) - else: - assert age_in_days - assert vo_constraint + """Delete pilots by stamps or by age. - cutoff_date = datetime.now(tz=timezone.utc) - timedelta(days=age_in_days) + Exactly one of `pilot_stamps` or `age_in_days` must be provided. - pilots = await get_outdated_pilots( + The age-based branch is used by the maintenance task worker (not exposed + on the public router). `vo_constraint` scopes an age-based deletion to + a single VO; pass `None` for cross-VO cleanup. + """ + if pilot_stamps is not None: + pilots = await get_pilots_by_stamp( pilot_db=pilot_db, - cutoff_date=cutoff_date, - only_aborted=delete_only_aborted, + pilot_stamps=pilot_stamps, parameters=["PilotID"], + ) + pilot_ids = [p["PilotID"] for p in pilots] + elif age_in_days is not None: + pilot_ids = await _list_pilots_for_age_cleanup( + pilot_db=pilot_db, + age_in_days=age_in_days, + delete_only_aborted=delete_only_aborted, vo_constraint=vo_constraint, ) + else: + raise ValueError("Exactly one of pilot_stamps or age_in_days must be provided.") - pilot_ids = [pilot["PilotID"] for pilot in pilots] + if not pilot_ids: + return await pilot_db.remove_jobs_from_pilots(pilot_ids) await pilot_db.delete_pilot_logs(pilot_ids) await pilot_db.delete_pilots(pilot_ids) -async def update_pilots_fields( - pilot_db: PilotAgentsDB, pilot_stamps_to_fields_mapping: list[PilotFieldsMapping] +async def _list_pilots_for_age_cleanup( + pilot_db: PilotAgentsDB, + age_in_days: int, + delete_only_aborted: bool, + vo_constraint: str | None, +) -> list[int]: + """Return pilot IDs older than `age_in_days`. + + Internal helper for age-based cleanup. The cutoff is compared server-side + via the search layer; the datetime is serialised as an ISO-8601 string to + avoid widening the search-spec type for this one caller. + """ + cutoff = (datetime.now(tz=timezone.utc) - timedelta(days=age_in_days)).isoformat() + + search: list[SearchSpec] = [ + ScalarSearchSpec( + parameter="SubmissionTime", + operator=ScalarSearchOperator.LESS_THAN, + value=cutoff, + ), + ] + if vo_constraint is not None: + search.append( + ScalarSearchSpec( + parameter="VO", + operator=ScalarSearchOperator.EQUAL, + value=vo_constraint, + ) + ) + if delete_only_aborted: + search.append( + ScalarSearchSpec( + parameter="Status", + operator=ScalarSearchOperator.EQUAL, + value=PilotStatus.ABORTED, + ) + ) + + _, pilots = await pilot_db.search_pilots( + parameters=["PilotID"], + search=search, + sorts=[], + ) + return [p["PilotID"] for p in pilots] + + +async def update_pilots_metadata( + pilot_db: PilotAgentsDB, + pilot_metadata: list[PilotMetadata], ): - await pilot_db.update_pilot_fields(pilot_stamps_to_fields_mapping) + """Bulk-update pilot metadata.""" + await pilot_db.update_pilot_metadata(pilot_metadata) -async def add_jobs_to_pilot( +async def assign_jobs_to_pilot( pilot_db: PilotAgentsDB, pilot_stamp: str, job_ids: list[int] ): - pilot_ids = await get_pilot_ids_by_stamps( - pilot_db=pilot_db, pilot_stamps=[pilot_stamp] + """Associate jobs with a pilot identified by its stamp.""" + pilots = await get_pilots_by_stamp( + pilot_db=pilot_db, + pilot_stamps=[pilot_stamp], + parameters=["PilotID"], ) - pilot_id = pilot_ids[0] - - now = datetime.now(tz=timezone.utc) - - # Prepare the list of dictionaries for bulk insertion - job_to_pilot_mapping = [ - {"PilotID": pilot_id, "JobID": job_id, "StartTime": now} for job_id in job_ids + if not pilots: + raise PilotNotFoundError(detail=f"pilot {pilot_stamp!r} does not exist") + pilot_id = pilots[0]["PilotID"] + + job_to_pilot_mapping: list[dict[str, Any]] = [ + { + "PilotID": pilot_id, + "JobID": job_id, + "StartTime": datetime.now(tz=timezone.utc), + } + for job_id in job_ids ] - await pilot_db.add_jobs_to_pilot( - job_to_pilot_mapping=job_to_pilot_mapping, - ) - - -async def get_pilot_jobs_ids_by_stamp( - pilot_db: PilotAgentsDB, pilot_stamp: str -) -> list[int]: - """Fetch pilot jobs by stamp.""" - try: - pilot_ids = await get_pilot_ids_by_stamps( - pilot_db=pilot_db, pilot_stamps=[pilot_stamp] - ) - pilot_id = pilot_ids[0] - except PilotNotFoundError: - return [] - - return await get_pilot_jobs_ids_by_pilot_id(pilot_db=pilot_db, pilot_id=pilot_id) + await pilot_db.assign_jobs_to_pilot(job_to_pilot_mapping=job_to_pilot_mapping) diff --git a/diracx-logic/src/diracx/logic/pilots/query.py b/diracx-logic/src/diracx/logic/pilots/query.py index edad8f7ec..090aa3101 100644 --- a/diracx-logic/src/diracx/logic/pilots/query.py +++ b/diracx-logic/src/diracx/logic/pilots/query.py @@ -1,15 +1,12 @@ from __future__ import annotations -from datetime import datetime from typing import Any -from diracx.core.exceptions import PilotNotFoundError -from diracx.core.models.pilot import PilotStatus +from diracx.core.exceptions import InvalidQueryError from diracx.core.models.search import ( ScalarSearchOperator, ScalarSearchSpec, SearchParams, - SearchSpec, SummaryParams, VectorSearchOperator, VectorSearchSpec, @@ -18,29 +15,138 @@ MAX_PER_PAGE = 10000 +# Pseudo-parameter accepted on POST /api/pilots/search. Resolves to a +# PilotID IN (...) filter via JobToPilotMapping. +JOB_ID_PSEUDO_PARAM = "JobID" +# Real column on PilotAgents that JobID would collide with if both +# were accepted in the same request body. +PILOT_ID_REAL_PARAM = "PilotID" + + +def _add_vo_constraint( + body: SearchParams | SummaryParams, vo_constraint: str | None +) -> None: + """Add a VO filter to the search body if a constraint is supplied. + + Admin callers pass `vo_constraint=None` to bypass the filter and query + across all VOs. Mirrors the intra-VO pattern of `logic/jobs/query.py`. + """ + if vo_constraint is None: + return + body.search.append( + ScalarSearchSpec( + parameter="VO", + operator=ScalarSearchOperator.EQUAL, + value=vo_constraint, + ) + ) + + +async def resolve_jobs_for_pilot_stamps( + pilot_db: PilotAgentsDB, pilot_stamps: list[str] +) -> list[int]: + """Resolve a batch of pilot stamps to the job IDs they have run. + + Used by `logic/jobs/query.py:search` to rewrite the `PilotStamp` + pseudo-parameter into a concrete `JobID` vector filter. + """ + return await pilot_db.job_ids_for_stamps(pilot_stamps) + + +async def _resolve_pilots_for_job_ids( + pilot_db: PilotAgentsDB, job_ids: list[int] +) -> list[int]: + """Resolve a batch of job IDs to the pilot IDs that have run them.""" + return await pilot_db.pilot_ids_for_job_ids(job_ids) + + +async def _rewrite_job_id_pseudo_param( + pilot_db: PilotAgentsDB, body: SearchParams +) -> bool: + """Rewrite any `JobID` pseudo-parameter in `body.search`. + + Collects every `JobID` filter, resolves them through + `JobToPilotMapping`, removes the originals from `body.search`, and + appends a single `PilotID IN (...)` vector filter. Returns `True` + if the resolution produced an empty list (in which case the caller + should short-circuit to an empty result), `False` otherwise. + + Supports `eq` and `in` operators only; every other operator raises + `InvalidQueryError` because the join semantics are ambiguous. + Combining a `JobID` pseudo-filter with a real `PilotID` filter in + the same body is also refused. + """ + matches = [ + spec for spec in body.search if spec.get("parameter") == JOB_ID_PSEUDO_PARAM + ] + if not matches: + return False + + if any(spec.get("parameter") == PILOT_ID_REAL_PARAM for spec in body.search): + raise InvalidQueryError( + f"Cannot combine {JOB_ID_PSEUDO_PARAM!r} pseudo-parameter with a " + f"real {PILOT_ID_REAL_PARAM!r} filter in the same request." + ) + + job_ids: list[int] = [] + for spec in matches: + operator = spec.get("operator") + if operator == ScalarSearchOperator.EQUAL: + job_ids.append(int(spec["value"])) # type: ignore[typeddict-item] + elif operator == VectorSearchOperator.IN: + job_ids.extend(int(v) for v in spec["values"]) # type: ignore[typeddict-item] + else: + raise InvalidQueryError( + f"Operator {operator!r} is not supported on the " + f"{JOB_ID_PSEUDO_PARAM!r} pseudo-parameter; use 'eq' or 'in'." + ) + + pilot_ids = await _resolve_pilots_for_job_ids(pilot_db, job_ids) + body.search = [ + spec for spec in body.search if spec.get("parameter") != JOB_ID_PSEUDO_PARAM + ] + if not pilot_ids: + return True + body.search.append( + VectorSearchSpec( + parameter=PILOT_ID_REAL_PARAM, + operator=VectorSearchOperator.IN, + values=pilot_ids, + ) + ) + return False + async def search( pilot_db: PilotAgentsDB, - user_vo: str, + vo_constraint: str | None, page: int = 1, per_page: int = 100, body: SearchParams | None = None, ) -> tuple[int, list[dict[str, Any]]]: - """Retrieve information about jobs.""" - # Apply a limit to per_page to prevent abuse of the API + """Retrieve information about pilots. + + `vo_constraint` restricts results to a single VO; pass `None` to + query across VOs (reserved for service administrators). + + Accepts a `JobID` pseudo-parameter in `body.search` (`eq`/`in` + only): it is resolved through `JobToPilotMapping` into a concrete + `PilotID` vector filter before the main query runs. Mirrors the + `PilotStamp` pseudo-parameter on `POST /api/jobs/search`. + """ if per_page > MAX_PER_PAGE: per_page = MAX_PER_PAGE if body is None: body = SearchParams() - body.search.append( - ScalarSearchSpec( - parameter="VO", operator=ScalarSearchOperator.EQUAL, value=user_vo - ) - ) + empty_after_rewrite = await _rewrite_job_id_pseudo_param(pilot_db, body) + if empty_after_rewrite: + return 0, [] - total, pilots = await pilot_db.search_pilots( + _add_vo_constraint(body, vo_constraint) + + return await pilot_db.search_pilots( body.parameters, body.search, body.sort, @@ -49,24 +155,38 @@ async def search( per_page=per_page, ) - return total, pilots + +async def summary( + pilot_db: PilotAgentsDB, + body: SummaryParams, + vo_constraint: str | None, +): + """Aggregate pilot counts suitable for plotting.""" + _add_vo_constraint(body, vo_constraint) + return await pilot_db.pilot_summary(body.grouping, body.search) async def get_pilots_by_stamp( pilot_db: PilotAgentsDB, pilot_stamps: list[str], - parameters: list[str] = [], - allow_missing: bool = True, -) -> list[dict[Any, Any]]: - """Get pilots by their stamp. - - If `allow_missing` is set to False, if a pilot is missing, PilotNotFoundError will be raised. + parameters: list[str] | None = None, +) -> list[dict[str, Any]]: + """Return the pilots whose stamp is in `pilot_stamps`. + + Missing stamps are silently omitted from the result. Callers that care + about completeness must compare the returned length to the input. + `PilotStamp` is always included in the returned parameters so callers + can identify which stamps were found. """ - if parameters: - parameters.append("PilotStamp") + if parameters is None: + query_parameters: list[str] | None = None + else: + query_parameters = list(parameters) + if "PilotStamp" not in query_parameters: + query_parameters.append("PilotStamp") _, pilots = await pilot_db.search_pilots( - parameters=parameters, + parameters=query_parameters, search=[ VectorSearchSpec( parameter="PilotStamp", @@ -75,117 +195,6 @@ async def get_pilots_by_stamp( ) ], sorts=[], - distinct=True, per_page=MAX_PER_PAGE, ) - - # allow_missing is set as True by default to mark explicitly when we allow or not - if not allow_missing: - # Custom handling, to see which pilot_stamp does not exist (if so, say which one) - found_keys = {row["PilotStamp"] for row in pilots} - missing = set(pilot_stamps) - found_keys - - if missing: - raise PilotNotFoundError( - detail=str(missing), - ) - - return pilots - - -async def get_pilot_ids_by_stamps( - pilot_db: PilotAgentsDB, pilot_stamps: list[str], allow_missing=False -) -> list[int]: - pilots = await get_pilots_by_stamp( - pilot_db=pilot_db, - pilot_stamps=pilot_stamps, - parameters=["PilotID"], - allow_missing=allow_missing, - ) - - return [pilot["PilotID"] for pilot in pilots] - - -async def get_pilot_jobs_ids_by_pilot_id( - pilot_db: PilotAgentsDB, pilot_id: int -) -> list[int]: - _, jobs = await pilot_db.search_pilot_to_job_mapping( - parameters=["JobID"], - search=[ - ScalarSearchSpec( - parameter="PilotID", - operator=ScalarSearchOperator.EQUAL, - value=pilot_id, - ) - ], - sorts=[], - distinct=True, - per_page=MAX_PER_PAGE, - ) - - return [job["JobID"] for job in jobs] - - -async def get_pilot_ids_by_job_id(pilot_db: PilotAgentsDB, job_id: int) -> list[int]: - _, pilots = await pilot_db.search_pilot_to_job_mapping( - parameters=["PilotID"], - search=[ - ScalarSearchSpec( - parameter="JobID", - operator=ScalarSearchOperator.EQUAL, - value=job_id, - ) - ], - sorts=[], - distinct=True, - per_page=MAX_PER_PAGE, - ) - - return [pilot["PilotID"] for pilot in pilots] - - -async def get_outdated_pilots( - pilot_db: PilotAgentsDB, - cutoff_date: datetime, - vo_constraint: str, - only_aborted: bool = True, - parameters: list[str] = [], -): - query: list[SearchSpec] = [ - ScalarSearchSpec( - parameter="SubmissionTime", - operator=ScalarSearchOperator.LESS_THAN, - value=cutoff_date, - ), - # Add VO to avoid deleting other VO's pilots - ScalarSearchSpec( - parameter="VO", operator=ScalarSearchOperator.EQUAL, value=vo_constraint - ), - ] - - if only_aborted: - query.append( - ScalarSearchSpec( - parameter="Status", - operator=ScalarSearchOperator.EQUAL, - value=PilotStatus.ABORTED, - ) - ) - - _, pilots = await pilot_db.search_pilots( - parameters=parameters, search=query, sorts=[] - ) - return pilots - - -async def summary(pilot_db: PilotAgentsDB, body: SummaryParams, vo: str): - """Show information suitable for plotting.""" - body.search.append( - { - "parameter": "VO", - "operator": ScalarSearchOperator.EQUAL, - "value": vo, - } - ) - return await pilot_db.pilot_summary(body.grouping, body.search) diff --git a/diracx-routers/src/diracx/routers/jobs/query.py b/diracx-routers/src/diracx/routers/jobs/query.py index 487334601..29f6ce263 100644 --- a/diracx-routers/src/diracx/routers/jobs/query.py +++ b/diracx-routers/src/diracx/routers/jobs/query.py @@ -11,7 +11,7 @@ ) from diracx.core.properties import JOB_ADMINISTRATOR from diracx.db.os import JobParametersDB -from diracx.db.sql import JobDB, JobLoggingDB +from diracx.db.sql import JobDB, JobLoggingDB, PilotAgentsDB from diracx.logic.jobs.query import MAX_PER_PAGE from diracx.logic.jobs.query import search as search_bl from diracx.logic.jobs.query import summary as summary_bl @@ -45,6 +45,20 @@ "sort": [{"parameter": "JobID", "direction": "asc"}], }, }, + "Jobs run on a given pilot": { + "summary": "Jobs run on a given pilot", + "description": ( + "Find all jobs that have run on a specific pilot. `PilotStamp` " + "is a pseudo-parameter resolved through `JobToPilotMapping` " + "into a `JobID` filter; only `eq` and `in` operators are " + "supported." + ), + "value": { + "search": [ + {"parameter": "PilotStamp", "operator": "eq", "value": "abc-123"} + ] + }, + }, } @@ -123,6 +137,7 @@ async def search( job_db: JobDB, job_parameters_db: JobParametersDB, job_logging_db: JobLoggingDB, + pilot_db: PilotAgentsDB, user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)], check_permissions: CheckWMSPolicyCallable, response: Response, @@ -143,6 +158,12 @@ async def search( By default, the search will return all jobs the user has access to, and all the fields of the job will be returned. + + A `PilotStamp` pseudo-parameter is also accepted in the `search` + filter list (operators `eq` / `in` only): it is transparently + resolved through `JobToPilotMapping` into a `JobID` filter, + allowing callers to ask "jobs run by this pilot" through the same + endpoint. """ await check_permissions(action=ActionType.QUERY, job_db=job_db) @@ -155,6 +176,7 @@ async def search( job_db=job_db, job_parameters_db=job_parameters_db, job_logging_db=job_logging_db, + pilot_db=pilot_db, preferred_username=preferred_username, vo=user_info.vo, page=page, diff --git a/diracx-routers/src/diracx/routers/pilots/access_policies.py b/diracx-routers/src/diracx/routers/pilots/access_policies.py index ecacd2710..f28ff9b06 100644 --- a/diracx-routers/src/diracx/routers/pilots/access_policies.py +++ b/diracx-routers/src/diracx/routers/pilots/access_policies.py @@ -16,17 +16,21 @@ class ActionType(StrEnum): - # Change some pilot fields + # Change pilot metadata (status, fields, etc.). Admin-only by default; + # legacy pilot X.509 identities can be allowed via `allow_legacy_pilots`. MANAGE_PILOTS = auto() - # Read some pilot info - READ_PILOT_FIELDS = auto() + # Read pilot metadata. Normal users can read their own VO's pilots; + # `SERVICE_ADMINISTRATOR` can read across VOs. + READ_PILOT_METADATA = auto() class PilotManagementAccessPolicy(BaseAccessPolicy): """Pilot management access policy. - * Every user can access data about his VO - * An administrator can modify a pilot. + * Every user can read pilots from their own VO. + * Service administrators can read across VOs and manage pilots. + * Legacy X.509 pilot identities may be allowed to manage themselves when + `allow_legacy_pilots=True` is passed by the route. """ @staticmethod @@ -42,82 +46,81 @@ async def policy( job_ids: list[int] | None = None, allow_legacy_pilots: bool = False, ): - assert action, "action is a mandatory parameter" + # Authorization is VO-scoped, not bound to the caller's + # own pilot stamp. This mirrors DIRAC's PilotManagerHandler, which has + # no ownership check either. + if action is None: + raise ValueError("action is a mandatory parameter") - # Users can query - # NOTE: Add into queries a VO constraint - # To manage pilots, user have to be an admin - # In some special cases (described with allow_legacy_pilots), we can allow pilots if action == ActionType.MANAGE_PILOTS: - # To make it clear, we separate - is_an_admin = SERVICE_ADMINISTRATOR in user_info.properties - is_a_pilot_if_allowed = ( + is_admin = SERVICE_ADMINISTRATOR in user_info.properties + is_legacy_pilot = ( allow_legacy_pilots and GENERIC_PILOT in user_info.properties ) - - if not is_an_admin and not is_a_pilot_if_allowed: + if not is_admin and not is_legacy_pilot: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, - detail="You don't have the permission to manage pilots.", + detail="Insufficient permissions to manage pilots.", ) - if action == ActionType.READ_PILOT_FIELDS: + if action == ActionType.READ_PILOT_METADATA: if GENERIC_PILOT in user_info.properties: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, - detail="Pilots can't read other pilots info.", + detail="Pilots cannot read other pilots' metadata.", ) - # - # Additional checks if job_ids or pilot_stamps are provided - # - - # First, if job_ids are provided, we check who is the owner - if job_db and job_ids: - job_owners = await job_db.summary( - ["Owner", "VO"], - [ + # If job IDs are provided, verify the user owns all of them. + # Using a direct search (rather than summary/aggregate equality) is + # clearer and gives a distinct 404 vs 403 on missing jobs. + if job_db is not None and job_ids: + _, owner_rows = await job_db.search( + parameters=["Owner", "VO"], + search=[ VectorSearchSpec( parameter="JobID", operator=VectorSearchOperator.IN, values=job_ids, ) ], + sorts=[], + per_page=len(set(job_ids)), ) - - expected_owner = { - "Owner": user_info.preferred_username, - "VO": user_info.vo, - "count": len(set(job_ids)), - } - # All the jobs belong to the user doing the query - # and all of them are present - if not job_owners == [expected_owner]: + if len(owner_rows) != len(set(job_ids)): + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail="One or more jobs do not exist.", + ) + if not all( + row["Owner"] == user_info.preferred_username + and row["VO"] == user_info.vo + for row in owner_rows + ): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, - detail="You don't have the rights to modify a pilot.", + detail=( + "Insufficient permissions to access all of the provided jobs." + ), ) - # This is for example when we submit pilots, we use the user VO, so no need to verify - if pilot_db and pilot_stamps: - # Else, check its VO + # If pilot stamps are provided, verify they all belong to the user's VO. + if pilot_db is not None and pilot_stamps: pilots = await get_pilots_by_stamp( pilot_db=pilot_db, pilot_stamps=pilot_stamps, parameters=["VO"], - allow_missing=True, ) - - if len(pilots) != len(pilot_stamps): + if len(pilots) != len(set(pilot_stamps)): raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, + status_code=status.HTTP_404_NOT_FOUND, detail="At least one pilot does not exist.", ) - if not all(pilot["VO"] == user_info.vo for pilot in pilots): raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, - detail="You don't have access to all pilots.", + detail=( + "Insufficient permissions to access all of the provided pilots." + ), ) diff --git a/diracx-routers/src/diracx/routers/pilots/management.py b/diracx-routers/src/diracx/routers/pilots/management.py index 7f892335a..c59bc8895 100644 --- a/diracx-routers/src/diracx/routers/pilots/management.py +++ b/diracx-routers/src/diracx/routers/pilots/management.py @@ -5,21 +5,16 @@ from fastapi import Body, Depends, HTTPException, Query, status -from diracx.core.exceptions import ( - PilotAlreadyExistsError, -) -from diracx.core.models.pilot import PilotFieldsMapping, PilotStatus +from diracx.core.models.pilot import PilotMetadata, PilotStatus from diracx.core.properties import GENERIC_PILOT, JOB_ADMINISTRATOR -from diracx.db.sql import JobDB, PilotAgentsDB +from diracx.db.sql import PilotAgentsDB from diracx.logic.pilots.management import ( delete_pilots as delete_pilots_bl, ) from diracx.logic.pilots.management import ( - get_pilot_jobs_ids_by_stamp, register_new_pilots, - update_pilots_fields, + update_pilots_metadata, ) -from diracx.logic.pilots.query import get_pilot_ids_by_job_id from diracx.routers.utils.users import AuthorizedUserInfo, verify_dirac_access_token from ..fastapi_classes import DiracxRouter @@ -32,11 +27,11 @@ @router.post("/") -async def add_pilot_stamps( +async def register_pilots( pilot_db: PilotAgentsDB, pilot_stamps: Annotated[ list[str], - Body(description="List of the pilot stamps we want to add to the db."), + Body(description="Stamps of the pilots to create."), ], vo: Annotated[str, Body(description="Pilot virtual organization.")], check_permissions: CheckPilotManagementPolicyCallable, @@ -51,133 +46,85 @@ async def add_pilot_stamps( Body(description="Association of a pilot reference with a pilot stamp."), ] = None, pilot_status: Annotated[ - PilotStatus, Body(description="Status of the pilots.") + PilotStatus, Body(description="Initial status of the pilots.") ] = PilotStatus.SUBMITTED, ): - """Endpoint where a you can create pilots with their references. + """Register a batch of pilots with their references. - If a pilot stamp already exists, it will block the insertion. + If any stamp already exists, the whole batch is rejected with a 409. """ - # TODO: Verify that grid types, sites, destination sites, etc. are valids + # TODO: Verify that grid types, sites, destination sites, etc. are valid await check_permissions( action=ActionType.MANAGE_PILOTS, allow_legacy_pilots=True, # dirac-admin-add-pilot ) - # Prevent someone who stole a pilot X509 to create thousands of pilots at a time - # (It would be still able to create thousands of pilots, but slower) - if GENERIC_PILOT in user_info.properties: - if len(pilot_stamps) != 1: - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail="As a pilot, you can only create yourself.", - ) - - if JOB_ADMINISTRATOR not in user_info.properties: - if not vo == user_info.vo: - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, - detail="You can create pilots only for your VO.", - ) + # Limit the damage a stolen pilot credential can do: a pilot identity + # can only register a single stamp per call. + if GENERIC_PILOT in user_info.properties and len(pilot_stamps) != 1: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Insufficient permissions to register more than one pilot.", + ) - try: - await register_new_pilots( - pilot_db=pilot_db, - pilot_stamps=pilot_stamps, - vo=vo, - grid_type=grid_type, - grid_site=grid_site, - destination_site=destination_site, - pilot_job_references=pilot_references, - status=pilot_status, + if JOB_ADMINISTRATOR not in user_info.properties and vo != user_info.vo: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Pilots can only be registered for your own VO.", ) - except PilotAlreadyExistsError as e: - raise HTTPException(status_code=status.HTTP_409_CONFLICT, detail=str(e)) from e + + await register_new_pilots( + pilot_db=pilot_db, + pilot_stamps=pilot_stamps, + vo=vo, + grid_type=grid_type, + grid_site=grid_site, + destination_site=destination_site, + pilot_job_references=pilot_references, + status=pilot_status, + ) @router.delete("/", status_code=HTTPStatus.NO_CONTENT) async def delete_pilots( pilot_db: PilotAgentsDB, check_permissions: CheckPilotManagementPolicyCallable, - user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)], pilot_stamps: Annotated[ - list[str] | None, Query(description="Stamps of the pilots we want to delete.") - ] = None, - age_in_days: Annotated[ - int | None, - Query( - description=( - "The number of days that define the maximum age of pilots to be deleted." - "Pilots older than this age will be considered for deletion." - ) - ), - ] = None, - delete_only_aborted: Annotated[ - bool, - Query( - description=( - "Flag indicating whether to only delete pilots whose status is 'Aborted'." - "If set to True, only pilots with the 'Aborted' status will be deleted." - "It is set by default as True to avoid any mistake." - "This flag is only used for deletion by time." - ) - ), - ] = False, + list[str], Query(description="Stamps of the pilots to delete.", min_length=1) + ], ): - """Endpoint to delete a pilot. + """Delete pilots by stamp. - Two features: + Deletes the pilot rows as well as their logs and job associations. - 1. Or you provide pilot_stamps, so you can delete pilots by their stamp - 2. Or you provide age_in_days, so you can delete pilots that lived more than age_in_days days. - - Note: If you delete a pilot, its logs and its associations with jobs WILL be deleted. + Age-based retention cleanup is deliberately *not* exposed here: it is + handled by the maintenance task worker. See + `diracx.logic.pilots.management.delete_pilots`. """ - vo_constraint: str | None = None - - # If we delete by pilot_stamps, we check that we can access them - # Else, we add a constraint to the request, to avoid deleting pilots from another VO - if pilot_stamps: - await check_permissions( - action=ActionType.MANAGE_PILOTS, - pilot_db=pilot_db, - pilot_stamps=pilot_stamps, - ) - else: - vo_constraint = user_info.vo - - if not pilot_stamps and not age_in_days: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="pilot_stamps or age_in_days have to be provided.", - ) - - await delete_pilots_bl( + await check_permissions( + action=ActionType.MANAGE_PILOTS, pilot_db=pilot_db, pilot_stamps=pilot_stamps, - age_in_days=age_in_days, - delete_only_aborted=delete_only_aborted, - vo_constraint=vo_constraint, ) + await delete_pilots_bl(pilot_db=pilot_db, pilot_stamps=pilot_stamps) + -EXAMPLE_UPDATE_FIELDS = { +EXAMPLE_UPDATE_METADATA = { "Update the BenchMark field": { "summary": "Update BenchMark", "description": "Update only the BenchMark for one pilot.", "value": { - "pilot_stamps_to_fields_mapping": [ - {"PilotStamp": "the_pilot_stamp", "BenchMark": 1.0} - ] + "pilot_metadata": [{"PilotStamp": "the_pilot_stamp", "BenchMark": 1.0}] }, }, "Update multiple statuses": { "summary": "Update multiple pilots", - "description": "Update multiple pilots statuses.", + "description": "Update statuses for multiple pilots at once.", "value": { - "pilot_stamps_to_fields_mapping": [ - {"PilotStamp": "the_first_pilot_stamp", "Status": "Waiting"}, - {"PilotStamp": "the_second_pilot_stamp", "Status": "Waiting"}, + "pilot_metadata": [ + {"PilotStamp": "first_stamp", "Status": "Waiting"}, + {"PilotStamp": "second_stamp", "Status": "Waiting"}, ] }, }, @@ -185,25 +132,25 @@ async def delete_pilots( @router.patch("/metadata", status_code=HTTPStatus.NO_CONTENT) -async def update_pilot_fields( - pilot_stamps_to_fields_mapping: Annotated[ - list[PilotFieldsMapping], +async def update_pilot_metadata( + pilot_metadata: Annotated[ + list[PilotMetadata], Body( - description="(pilot_stamp, pilot_fields) mapping to change.", + description="Pilot metadata mappings to apply.", embed=True, - openapi_examples=EXAMPLE_UPDATE_FIELDS, # type: ignore + openapi_examples=EXAMPLE_UPDATE_METADATA, # type: ignore ), ], pilot_db: PilotAgentsDB, check_permissions: CheckPilotManagementPolicyCallable, user_info: Annotated[AuthorizedUserInfo, Depends(verify_dirac_access_token)], ): - """Modify a field of a pilot. + """Update pilot metadata (status, benchmark, etc.). - Note: Only the fields in PilotFieldsMapping are mutable, except for the PilotStamp. + Only fields defined in `PilotMetadata` are mutable. `PilotStamp` + identifies the row and cannot be changed. """ - # Ensures stamps validity - pilot_stamps = [mapping.PilotStamp for mapping in pilot_stamps_to_fields_mapping] + pilot_stamps = [m.PilotStamp for m in pilot_metadata] await check_permissions( action=ActionType.MANAGE_PILOTS, pilot_db=pilot_db, @@ -211,54 +158,15 @@ async def update_pilot_fields( allow_legacy_pilots=True, # dirac-admin-add-pilot ) - # Prevent someone who stole a pilot X509 to modify thousands of pilots at a time - # (It would be still able to modify thousands of pilots, but slower) - # We are not able to affirm that this pilot modifies itself - if GENERIC_PILOT in user_info.properties: - if len(pilot_stamps) != 1: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="As a pilot, you can only modify yourself.", - ) - - await update_pilots_fields( - pilot_db=pilot_db, - pilot_stamps_to_fields_mapping=pilot_stamps_to_fields_mapping, - ) - - -@router.get("/jobs") -async def get_pilot_jobs( - pilot_db: PilotAgentsDB, - job_db: JobDB, - check_permissions: CheckPilotManagementPolicyCallable, - pilot_stamp: Annotated[ - str | None, Query(description="The stamp of the pilot.") - ] = None, - job_id: Annotated[int | None, Query(description="The ID of the job.")] = None, -) -> list[int]: - """Endpoint only for admins, to get jobs of a pilot.""" - if pilot_stamp: - # Check VO - await check_permissions( - action=ActionType.READ_PILOT_FIELDS, - pilot_db=pilot_db, - pilot_stamps=[pilot_stamp], - ) - - return await get_pilot_jobs_ids_by_stamp( - pilot_db=pilot_db, - pilot_stamp=pilot_stamp, - ) - elif job_id: - # Check job owner - await check_permissions( - action=ActionType.READ_PILOT_FIELDS, job_db=job_db, job_ids=[job_id] + # Limit the damage a stolen pilot credential can do: a pilot identity + # can only modify a single stamp per call. + if GENERIC_PILOT in user_info.properties and len(pilot_stamps) != 1: + raise HTTPException( + status_code=status.HTTP_403_FORBIDDEN, + detail="Insufficient permissions to modify more than one pilot.", ) - return await get_pilot_ids_by_job_id(pilot_db=pilot_db, job_id=job_id) - - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail="You must provide either pilot_stamp or job_id", + await update_pilots_metadata( + pilot_db=pilot_db, + pilot_metadata=pilot_metadata, ) diff --git a/diracx-routers/src/diracx/routers/pilots/query.py b/diracx-routers/src/diracx/routers/pilots/query.py index 8b956ec68..c26fc5cb4 100644 --- a/diracx-routers/src/diracx/routers/pilots/query.py +++ b/diracx-routers/src/diracx/routers/pilots/query.py @@ -6,6 +6,7 @@ from fastapi import Body, Depends, Query, Response from diracx.core.models.search import SearchParams, SummaryParams +from diracx.core.properties import SERVICE_ADMINISTRATOR from diracx.db.sql import PilotAgentsDB from diracx.logic.pilots.query import MAX_PER_PAGE from diracx.logic.pilots.query import search as search_bl @@ -20,6 +21,14 @@ router = DiracxRouter() + +def _vo_constraint_for(user_info: AuthorizedUserInfo) -> str | None: + """Return the VO filter to apply for this user, or None for admins.""" + if SERVICE_ADMINISTRATOR in user_info.properties: + return None + return user_info.vo + + EXAMPLE_SEARCHES = { "Show all": { "summary": "Show all", @@ -42,6 +51,15 @@ "sort": [{"parameter": "PilotID", "direction": "asc"}], }, }, + "Pilots that ran a given job": { + "summary": "Pilots that ran a given job", + "description": ( + "Find all pilots that have run a specific job. `JobID` is a " + "pseudo-parameter resolved through `JobToPilotMapping` into a " + "`PilotID` filter; only `eq` and `in` operators are supported." + ), + "value": {"search": [{"parameter": "JobID", "operator": "eq", "value": 42}]}, + }, } @@ -55,18 +73,10 @@ "PilotID": 3, "SubmissionTime": "2023-05-25T07:03:35.602654", "LastUpdateTime": "2023-05-25T07:03:35.602656", - "Status": "RUNNING", + "Status": "Running", "GridType": "Dirac", "BenchMark": 1.0, }, - { - "PilotID": 5, - "SubmissionTime": "2023-06-25T07:03:35.602654", - "LastUpdateTime": "2023-07-25T07:03:35.602652", - "Status": "RUNNING", - "GridType": "Dirac", - "BenchMark": 63.1, - }, ] } }, @@ -80,28 +90,6 @@ } }, "model": list[dict[str, Any]], - "content": { - "application/json": { - "example": [ - { - "PilotID": 3, - "SubmissionTime": "2023-05-25T07:03:35.602654", - "LastUpdateTime": "2023-05-25T07:03:35.602656", - "Status": "RUNNING", - "GridType": "Dirac", - "BenchMark": 1.0, - }, - { - "PilotID": 5, - "SubmissionTime": "2023-06-25T07:03:35.602654", - "LastUpdateTime": "2023-07-25T07:03:35.602652", - "Status": "RUNNING", - "GridType": "Dirac", - "BenchMark": 63.1, - }, - ] - } - }, }, } @@ -118,29 +106,30 @@ async def search( SearchParams | None, Body(openapi_examples=EXAMPLE_SEARCHES) # type: ignore ] = None, ) -> list[dict[str, Any]]: - """Retrieve information about pilots.""" - # Inspired by /api/jobs/query - await check_permissions(action=ActionType.READ_PILOT_FIELDS) + """Retrieve information about pilots. + + Normal users see only their own VO's pilots. Service administrators see + pilots from all VOs. + + A `JobID` pseudo-parameter is also accepted in the `search` filter + list (operators `eq` / `in` only): it is transparently resolved + through `JobToPilotMapping` into a `PilotID` filter, allowing + callers to ask "pilots that ran this job" through the same endpoint. + """ + await check_permissions(action=ActionType.READ_PILOT_METADATA) total, pilots = await search_bl( pilot_db=pilot_db, - user_vo=user_info.vo, + vo_constraint=_vo_constraint_for(user_info), page=page, per_page=per_page, body=body, ) - # Set the Content-Range header if needed - # https://datatracker.ietf.org/doc/html/rfc7233#section-4 - - # No pilots found but there are pilots for the requested search - # https://datatracker.ietf.org/doc/html/rfc7233#section-4.4 + # RFC 7233 Content-Range handling, matching /api/jobs/search if len(pilots) == 0 and total > 0: response.headers["Content-Range"] = f"pilots */{total}" response.status_code = HTTPStatus.REQUESTED_RANGE_NOT_SATISFIABLE - - # The total number of pilots is greater than the number of pilots returned - # https://datatracker.ietf.org/doc/html/rfc7233#section-4.2 elif len(pilots) < total: first_idx = per_page * (page - 1) last_idx = min(first_idx + len(pilots), total) - 1 if total > 0 else 0 @@ -156,11 +145,15 @@ async def summary( body: SummaryParams, check_permissions: CheckPilotManagementPolicyCallable, ): - """Show information suitable for plotting.""" - await check_permissions(action=ActionType.READ_PILOT_FIELDS) + """Aggregate pilot counts suitable for plotting. + + Normal users see only their own VO's pilots. Service administrators see + pilots from all VOs. + """ + await check_permissions(action=ActionType.READ_PILOT_METADATA) return await summary_bl( pilot_db=pilot_db, body=body, - vo=user_info.vo, + vo_constraint=_vo_constraint_for(user_info), ) diff --git a/diracx-routers/tests/jobs/test_heartbeat_commands.py b/diracx-routers/tests/jobs/test_heartbeat_commands.py index 8ef991846..d2436ca42 100644 --- a/diracx-routers/tests/jobs/test_heartbeat_commands.py +++ b/diracx-routers/tests/jobs/test_heartbeat_commands.py @@ -19,6 +19,8 @@ "WMSAccessPolicy", "DevelopmentSettings", "JobParametersDB", + "PilotAgentsDB", + "PilotManagementAccessPolicy", ] ) diff --git a/diracx-routers/tests/jobs/test_query.py b/diracx-routers/tests/jobs/test_query.py index e40db9bd5..323ee1143 100644 --- a/diracx-routers/tests/jobs/test_query.py +++ b/diracx-routers/tests/jobs/test_query.py @@ -37,6 +37,8 @@ "WMSAccessPolicy", "DevelopmentSettings", "JobParametersDB", + "PilotAgentsDB", + "PilotManagementAccessPolicy", ] ) @@ -917,3 +919,125 @@ def test_summary_doc_example(normal_user_client: TestClient, valid_job_id: int): assert r.status_code == 200, r.json() assert len(r.json()) == 1 + + +# --------------------------------------------------------------------------- +# Cross-table search: PilotStamp pseudo-parameter on POST /api/jobs/search +# --------------------------------------------------------------------------- + + +async def _assign_pilot_to_jobs(client, stamp: str, job_ids: list[int]) -> None: + """Insert JobToPilotMapping rows directly. + + The router does not expose a public endpoint for pilot-job association + (deliberately — it waits for the DiracX pilot token story). Tests reach + into the app's dependency override to insert the rows via the DB layer. + """ + from diracx.db.sql import PilotAgentsDB + from diracx.logic.pilots.management import assign_jobs_to_pilot + + db = client.app.dependency_overrides[PilotAgentsDB.transaction].args[0] + async with db: + await assign_jobs_to_pilot(pilot_db=db, pilot_stamp=stamp, job_ids=job_ids) + + +async def _register_pilot(client, stamp: str) -> None: + r = client.post( + "/api/pilots/", + json={"pilot_stamps": [stamp], "vo": "lhcb"}, + ) + assert r.status_code == 200, r.json() + + +async def test_jobs_search_by_pilot_stamp_eq(normal_user_client): + """A ``PilotStamp`` eq filter on /jobs/search returns the jobs that ran on that pilot.""" + r = normal_user_client.post("/api/jobs/jdl", json=[TEST_JDL for _ in range(3)]) + assert r.status_code == 200, r.json() + job_ids = [j["JobID"] for j in r.json()] + + await _register_pilot(normal_user_client, "stamp-eq") + await _assign_pilot_to_jobs( + normal_user_client, "stamp-eq", [job_ids[0], job_ids[1]] + ) + + r = normal_user_client.post( + "/api/jobs/search", + json={ + "search": [ + {"parameter": "PilotStamp", "operator": "eq", "value": "stamp-eq"} + ] + }, + ) + assert r.status_code == 200, r.json() + returned = sorted(j["JobID"] for j in r.json()) + assert returned == sorted([job_ids[0], job_ids[1]]) + + +async def test_jobs_search_by_pilot_stamp_in_multiple(normal_user_client): + """An ``in`` filter over several stamps returns the union of their jobs.""" + r = normal_user_client.post("/api/jobs/jdl", json=[TEST_JDL for _ in range(4)]) + assert r.status_code == 200, r.json() + job_ids = [j["JobID"] for j in r.json()] + + await _register_pilot(normal_user_client, "stamp-in-a") + await _register_pilot(normal_user_client, "stamp-in-b") + await _assign_pilot_to_jobs(normal_user_client, "stamp-in-a", [job_ids[0]]) + await _assign_pilot_to_jobs( + normal_user_client, "stamp-in-b", [job_ids[1], job_ids[2]] + ) + + r = normal_user_client.post( + "/api/jobs/search", + json={ + "search": [ + { + "parameter": "PilotStamp", + "operator": "in", + "values": ["stamp-in-a", "stamp-in-b"], + } + ] + }, + ) + assert r.status_code == 200, r.json() + returned = sorted(j["JobID"] for j in r.json()) + assert returned == sorted([job_ids[0], job_ids[1], job_ids[2]]) + + +def test_jobs_search_by_unknown_pilot_stamp_returns_empty(normal_user_client): + """An unknown stamp resolves to an empty job list; the caller gets ``[]``.""" + r = normal_user_client.post("/api/jobs/jdl", json=[TEST_JDL]) + assert r.status_code == 200, r.json() + + r = normal_user_client.post( + "/api/jobs/search", + json={ + "search": [{"parameter": "PilotStamp", "operator": "eq", "value": "nope"}] + }, + ) + assert r.status_code == 200 + assert r.json() == [] + + +def test_jobs_search_combining_pilot_stamp_and_job_id_raises(normal_user_client): + """Combining a ``PilotStamp`` pseudo-filter with a real ``JobID`` filter is refused.""" + r = normal_user_client.post( + "/api/jobs/search", + json={ + "search": [ + {"parameter": "PilotStamp", "operator": "eq", "value": "any"}, + {"parameter": "JobID", "operator": "eq", "value": 1}, + ] + }, + ) + assert r.status_code in (400, 422), r.json() + + +def test_jobs_search_pilot_stamp_unsupported_operator_raises(normal_user_client): + """Operators other than ``eq`` / ``in`` on ``PilotStamp`` are refused.""" + r = normal_user_client.post( + "/api/jobs/search", + json={ + "search": [{"parameter": "PilotStamp", "operator": "neq", "value": "any"}] + }, + ) + assert r.status_code in (400, 422), r.json() diff --git a/diracx-routers/tests/jobs/test_status.py b/diracx-routers/tests/jobs/test_status.py index 25b69eb1f..773930b88 100644 --- a/diracx-routers/tests/jobs/test_status.py +++ b/diracx-routers/tests/jobs/test_status.py @@ -26,6 +26,8 @@ "WMSAccessPolicy", "DevelopmentSettings", "JobParametersDB", + "PilotAgentsDB", + "PilotManagementAccessPolicy", ] ) diff --git a/diracx-routers/tests/pilots/__init__.py b/diracx-routers/tests/pilots/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/diracx-routers/tests/pilots/test_access_policy.py b/diracx-routers/tests/pilots/test_access_policy.py new file mode 100644 index 000000000..f8f1e2b54 --- /dev/null +++ b/diracx-routers/tests/pilots/test_access_policy.py @@ -0,0 +1,127 @@ +"""Unit tests for `PilotManagementAccessPolicy`. + +These tests bypass the FastAPI test harness (which stubs the real policy +with `AlwaysAllowAccessPolicy`) and invoke the policy coroutine +directly, mirroring how it is called from a real request. +""" + +from __future__ import annotations + +from uuid import uuid4 + +import pytest +from fastapi import HTTPException + +from diracx.core.properties import GENERIC_PILOT, NORMAL_USER, SERVICE_ADMINISTRATOR +from diracx.db.sql.pilots.db import PilotAgentsDB +from diracx.routers.pilots.access_policies import ( + ActionType, + PilotManagementAccessPolicy, +) +from diracx.routers.utils.users import AuthorizedUserInfo + +MAIN_VO = "lhcb" + + +def _user(*properties, vo: str = MAIN_VO) -> AuthorizedUserInfo: + """Build a minimal AuthorizedUserInfo for policy tests.""" + return AuthorizedUserInfo( + bearer_token="", + token_id=str(uuid4()), + properties=list(properties), + sub="testingVO:sub", + preferred_username="test-user", + dirac_group="test_group", + vo=vo, + policies={}, + ) + + +@pytest.fixture +async def pilot_db_with_pilots(): + """Yield a pilot DB seeded with two pilots, both in MAIN_VO.""" + db = PilotAgentsDB("sqlite+aiosqlite:///:memory:") + async with db.engine_context(): + async with db.engine.begin() as conn: + await conn.run_sync(db.metadata.create_all) + async with db as pdb: + await pdb.register_pilots(pilot_stamps=["stamp-a", "stamp-b"], vo=MAIN_VO) + yield db + + +async def test_manage_requires_service_administrator(): + """A normal user cannot manage pilots.""" + with pytest.raises(HTTPException) as exc_info: + await PilotManagementAccessPolicy.policy( + "PilotManagementAccessPolicy", + _user(NORMAL_USER), + action=ActionType.MANAGE_PILOTS, + ) + assert exc_info.value.status_code == 403 + + +async def test_manage_allows_service_administrator(): + await PilotManagementAccessPolicy.policy( + "PilotManagementAccessPolicy", + _user(SERVICE_ADMINISTRATOR), + action=ActionType.MANAGE_PILOTS, + ) + + +async def test_manage_allows_legacy_pilot_when_opted_in(): + """`allow_legacy_pilots=True` lets GENERIC_PILOT identities manage.""" + await PilotManagementAccessPolicy.policy( + "PilotManagementAccessPolicy", + _user(GENERIC_PILOT), + action=ActionType.MANAGE_PILOTS, + allow_legacy_pilots=True, + ) + + +async def test_manage_rejects_legacy_pilot_when_not_opted_in(): + with pytest.raises(HTTPException) as exc_info: + await PilotManagementAccessPolicy.policy( + "PilotManagementAccessPolicy", + _user(GENERIC_PILOT), + action=ActionType.MANAGE_PILOTS, + ) + assert exc_info.value.status_code == 403 + + +async def test_read_denies_generic_pilots(): + """A pilot identity is not allowed to read other pilots' metadata.""" + with pytest.raises(HTTPException) as exc_info: + await PilotManagementAccessPolicy.policy( + "PilotManagementAccessPolicy", + _user(GENERIC_PILOT), + action=ActionType.READ_PILOT_METADATA, + ) + assert exc_info.value.status_code == 403 + + +async def test_pilot_stamp_check_raises_404_on_unknown(pilot_db_with_pilots): + """Supplying an unknown pilot stamp must surface as 404.""" + async with pilot_db_with_pilots as db: + with pytest.raises(HTTPException) as exc_info: + await PilotManagementAccessPolicy.policy( + "PilotManagementAccessPolicy", + _user(SERVICE_ADMINISTRATOR), + action=ActionType.MANAGE_PILOTS, + pilot_db=db, + pilot_stamps=["stamp-a", "nope"], + ) + assert exc_info.value.status_code == 404 + + +async def test_pilot_stamp_check_raises_403_on_cross_vo(pilot_db_with_pilots): + """A user from another VO must not be able to act on this VO's pilots.""" + async with pilot_db_with_pilots as db: + with pytest.raises(HTTPException) as exc_info: + await PilotManagementAccessPolicy.policy( + "PilotManagementAccessPolicy", + _user(SERVICE_ADMINISTRATOR, vo="other-vo"), + action=ActionType.MANAGE_PILOTS, + pilot_db=db, + pilot_stamps=["stamp-a"], + ) + assert exc_info.value.status_code == 403 diff --git a/diracx-routers/tests/pilots/test_pilot_creation.py b/diracx-routers/tests/pilots/test_pilot_creation.py index d0cea485e..3ea87d088 100644 --- a/diracx-routers/tests/pilots/test_pilot_creation.py +++ b/diracx-routers/tests/pilots/test_pilot_creation.py @@ -1,22 +1,17 @@ -from __future__ import annotations +"""Router-level tests for pilot register / update / delete.""" -from datetime import datetime, timezone +from __future__ import annotations import pytest -from sqlalchemy import update -from diracx.core.models.pilot import PilotFieldsMapping, PilotStatus -from diracx.db.sql import PilotAgentsDB -from diracx.db.sql.pilots.schema import PilotAgents +from diracx.core.models.pilot import PilotMetadata, PilotStatus pytestmark = pytest.mark.enabled_dependencies( [ - "PilotCredentialsAccessPolicy", "DevelopmentSettings", "AuthDB", "AuthSettings", "ConfigSource", - "BaseAccessPolicy", "PilotAgentsDB", "PilotManagementAccessPolicy", "JobDB", @@ -24,7 +19,6 @@ ) MAIN_VO = "lhcb" -N = 100 @pytest.fixture @@ -33,249 +27,121 @@ def normal_test_client(client_factory): yield client -async def test_create_pilots(normal_test_client): - # Lots of request, to validate that it returns the credentials in the same order as the input references - pilot_stamps = [f"stamps_{i}" for i in range(N)] - - # -------------- Bulk insert -------------- - body = {"pilot_stamps": pilot_stamps, "vo": MAIN_VO} +async def test_register_then_duplicate_then_success(normal_test_client): + """Registering an existing stamp is rejected with 409; a fresh one still succeeds.""" + pilot_stamps = [f"stamps_{i}" for i in range(5)] r = normal_test_client.post( - "/api/pilots/", - json=body, + "/api/pilots/", json={"pilot_stamps": pilot_stamps, "vo": MAIN_VO} ) - assert r.status_code == 200, r.json() - # -------------- Register a pilot that already exists, and one that does not -------------- - - body = { - "pilot_stamps": [pilot_stamps[0], pilot_stamps[0] + "_new_one"], - "vo": MAIN_VO, - } - - r = normal_test_client.post( - "/api/pilots/", - json=body, - headers={ - "Content-Type": "application/json", - }, - ) - - assert r.status_code == 409 - assert ( - r.json()["detail"] - == f"The following pilots already exist: {{'{pilot_stamps[0]}'}}" - ) - - # -------------- Register a pilot that does not exists **but** was called before in an error -------------- - # To prove that, if I tried to register a pilot that does not exist with one that already exists, - # i can normally add the one that did not exist before (it should not have added it before) - body = {"pilot_stamps": [pilot_stamps[0] + "_new_one"], "vo": MAIN_VO} - + # Mix of existing and new stamps: whole batch is rejected r = normal_test_client.post( "/api/pilots/", - json=body, - headers={ - "Content-Type": "application/json", + json={ + "pilot_stamps": [pilot_stamps[0], "stamps_new"], + "vo": MAIN_VO, }, ) + assert r.status_code == 409, r.json() - assert r.status_code == 200 - - -async def test_create_pilot_and_delete_it(normal_test_client): - pilot_stamp = "stamps_1" - - # -------------- Insert -------------- - body = {"pilot_stamps": [pilot_stamp], "vo": MAIN_VO} - - # Create a pilot + # The new stamp alone was NOT committed by the failing call above r = normal_test_client.post( - "/api/pilots/", - json=body, + "/api/pilots/", json={"pilot_stamps": ["stamps_new"], "vo": MAIN_VO} ) - assert r.status_code == 200, r.json() - # -------------- Duplicate -------------- - # Duplicate because it exists, should have 409 - r = normal_test_client.post( - "/api/pilots/", - json=body, - ) - - assert r.status_code == 409, r.json() - # -------------- Delete -------------- - params = {"pilot_stamps": [pilot_stamp]} - - # We delete the pilot - r = normal_test_client.delete( - "/api/pilots/", - params=params, +async def test_register_delete_by_stamp_roundtrip(normal_test_client): + r = normal_test_client.post( + "/api/pilots/", json={"pilot_stamps": ["stamp_a"], "vo": MAIN_VO} ) + assert r.status_code == 200 + r = normal_test_client.delete("/api/pilots/", params={"pilot_stamps": ["stamp_a"]}) assert r.status_code == 204 - # -------------- Insert -------------- - # Create a the same pilot, but works because it does not exist anymore + # Now the stamp is free again r = normal_test_client.post( - "/api/pilots/", - json=body, + "/api/pilots/", json={"pilot_stamps": ["stamp_a"], "vo": MAIN_VO} ) + assert r.status_code == 200 - assert r.status_code == 200, r.json() - - -async def test_create_pilot_and_modify_it(normal_test_client): - pilot_stamps = ["stamps_1", "stamp_2"] - - # -------------- Insert -------------- - body = {"pilot_stamps": pilot_stamps, "vo": MAIN_VO} - # Create pilots +async def test_update_pilot_metadata_applies_partial_fields(normal_test_client): + """PATCH /pilots/metadata supports heterogeneous field subsets per row.""" + stamps = ["stamp_m1", "stamp_m2"] r = normal_test_client.post( - "/api/pilots/", - json=body, + "/api/pilots/", json={"pilot_stamps": stamps, "vo": MAIN_VO} ) + assert r.status_code == 200 - assert r.status_code == 200, r.json() - - # -------------- Modify -------------- - # We modify only the first pilot - body = { - "pilot_stamps_to_fields_mapping": [ - PilotFieldsMapping( - PilotStamp=pilot_stamps[0], - BenchMark=1.0, - StatusReason="NewReason", - AccountingSent=True, - Status=PilotStatus.WAITING, - ).model_dump(exclude_unset=True) - ] - } - - r = normal_test_client.patch("/api/pilots/metadata", json=body) - - assert r.status_code == 204 - - body = { - "parameters": [], - "search": [], - "sort": [], - "distinct": True, - } - - r = normal_test_client.post("/api/pilots/search", json=body) - assert r.status_code == 200, r.json() - pilot1 = r.json()[0] - pilot2 = r.json()[1] - - assert pilot1["BenchMark"] == 1.0 - assert pilot1["StatusReason"] == "NewReason" - assert pilot1["AccountingSent"] - assert pilot1["Status"] == PilotStatus.WAITING - - assert pilot2["BenchMark"] != pilot1["BenchMark"] - assert pilot2["StatusReason"] != pilot1["StatusReason"] - assert pilot2["AccountingSent"] != pilot1["AccountingSent"] - assert pilot2["Status"] != pilot1["Status"] - - -@pytest.mark.asyncio -async def test_delete_pilots_by_age_and_stamp(normal_test_client): - # Generate 100 pilot stamps - pilot_stamps = [f"stamp_{i}" for i in range(100)] - - # -------------- Insert all pilots -------------- - body = {"pilot_stamps": pilot_stamps, "vo": MAIN_VO} - r = normal_test_client.post("/api/pilots/", json=body) - assert r.status_code == 200, r.json() - - # -------------- Modify last 50 pilots' fields -------------- - to_modify = pilot_stamps[50:] - mappings = [] - for idx, stamp in enumerate(to_modify): - # First 25 of modified set to ABORTED, others to WAITING - status = PilotStatus.ABORTED if idx < 25 else PilotStatus.WAITING - mapping = PilotFieldsMapping( - PilotStamp=stamp, - BenchMark=idx + 0.1, - StatusReason=f"Reason_{idx}", - AccountingSent=(idx % 2 == 0), - Status=status, - ).model_dump(exclude_unset=True) - mappings.append(mapping) - + # stamp_m1 updates only BenchMark; stamp_m2 only Status r = normal_test_client.patch( "/api/pilots/metadata", - json={"pilot_stamps_to_fields_mapping": mappings}, + json={ + "pilot_metadata": [ + PilotMetadata(PilotStamp="stamp_m1", BenchMark=1.0).model_dump( + exclude_unset=True + ), + PilotMetadata( + PilotStamp="stamp_m2", Status=PilotStatus.WAITING + ).model_dump(exclude_unset=True), + ] + }, ) - assert r.status_code == 204 - - # -------------- Directly set SubmissionTime to March 14, 2003 for last 50 -------------- - old_date = datetime(2003, 3, 14, tzinfo=timezone.utc) - # Access DB session from normal_test_client fixtures - db = normal_test_client.app.dependency_overrides[PilotAgentsDB.transaction].args[0] + assert r.status_code == 204, r.json() - async with db: - stmt = ( - update(PilotAgents) - .where(PilotAgents.pilot_stamp.in_(to_modify)) - .values(SubmissionTime=old_date) - ) - await db.conn.execute(stmt) - await db.conn.commit() - - # -------------- Verify all 100 pilots exist -------------- - search_body = {"parameters": [], "search": [], "sort": [], "distinct": True} - r = normal_test_client.post("/api/pilots/search", json=search_body) - assert r.status_code == 200, r.json() - assert len(r.json()) == 100 - - # -------------- 1) Delete only old aborted pilots (25 expected) -------------- - # age_in_days large enough to include 2003-03-14 + r = normal_test_client.post("/api/pilots/search", json={}) + assert r.status_code == 200 + by_stamp = {p["PilotStamp"]: p for p in r.json()} + assert by_stamp["stamp_m1"]["BenchMark"] == 1.0 + assert by_stamp["stamp_m1"]["Status"] == PilotStatus.SUBMITTED # untouched + assert by_stamp["stamp_m2"]["Status"] == PilotStatus.WAITING + assert by_stamp["stamp_m2"]["BenchMark"] == 0.0 # untouched + + +async def test_delete_unknown_stamp_is_a_noop(normal_test_client): + """Deleting an unknown stamp is a safe no-op under the test harness. + + The test harness replaces `PilotManagementAccessPolicy` with + `AlwaysAllowAccessPolicy`, so the real policy's unknown-stamp 404 + branch is exercised by the dedicated policy unit test + (`test_access_policy.py`). Here we only verify the router path does + not explode and is safely idempotent. + """ r = normal_test_client.delete( - "/api/pilots/", - params={"age_in_days": 15, "delete_only_aborted": True}, + "/api/pilots/", params={"pilot_stamps": ["does_not_exist"]} ) assert r.status_code == 204 - # Expect 75 remaining - r = normal_test_client.post("/api/pilots/search", json=search_body) - assert len(r.json()) == 75 - # -------------- 2) Delete all old pilots (remaining 25 old) -------------- - r = normal_test_client.delete( - "/api/pilots/", - params={"age_in_days": 15}, - ) - assert r.status_code == 204 - # Expect 50 remaining - r = normal_test_client.post("/api/pilots/search", json=search_body) - assert len(r.json()) == 50 +async def test_delete_requires_at_least_one_stamp(normal_test_client): + """DELETE with no stamps must return 422 (FastAPI validation).""" + r = normal_test_client.delete("/api/pilots/") + assert r.status_code == 422, r.json() - # -------------- 3) Delete one recent pilot by stamp -------------- - one_stamp = pilot_stamps[10] - r = normal_test_client.delete("/api/pilots/", params={"pilot_stamps": [one_stamp]}) - assert r.status_code == 204 - # Expect 49 remaining - r = normal_test_client.post("/api/pilots/search", json=search_body) - assert len(r.json()) == 49 - # -------------- 4) Delete all remaining pilots -------------- - # Collect remaining stamps - remaining = [p["PilotStamp"] for p in r.json()] - r = normal_test_client.delete("/api/pilots/", params={"pilot_stamps": remaining}) - assert r.status_code == 204 - # Expect none remaining - r = normal_test_client.post("/api/pilots/search", json=search_body) - assert r.status_code == 200 - assert len(r.json()) == 0 +async def test_unknown_query_params_do_not_trigger_deletion(normal_test_client): + """Age-based cleanup is handled by the task worker, not the HTTP API. - # -------------- 5) Attempt deleting unknown pilot, expect 400 -------------- - r = normal_test_client.delete( - "/api/pilots/", params={"pilot_stamps": ["unknown_stamp"]} + The router must NOT accept age_in_days; any such param is either + ignored by FastAPI or returns 422 on unexpected query usage. The key + observation is that passing `age_in_days` alone (without + `pilot_stamps`) must not silently wipe pilots. + """ + # Create a pilot to ensure there's something that could be deleted + r = normal_test_client.post( + "/api/pilots/", json={"pilot_stamps": ["stamp_safe"], "vo": MAIN_VO} ) - assert r.status_code == 204 + assert r.status_code == 200 + + # age_in_days alone is rejected because pilot_stamps is required + r = normal_test_client.delete("/api/pilots/", params={"age_in_days": 1}) + assert r.status_code == 422 + + # Our pilot is still there + r = normal_test_client.post("/api/pilots/search", json={}) + assert r.status_code == 200 + assert any(p["PilotStamp"] == "stamp_safe" for p in r.json()) diff --git a/diracx-routers/tests/pilots/test_query.py b/diracx-routers/tests/pilots/test_query.py index c6c1c7e35..03a7a8082 100644 --- a/diracx-routers/tests/pilots/test_query.py +++ b/diracx-routers/tests/pilots/test_query.py @@ -1,20 +1,12 @@ -"""Inspired by pilots and jobs db search tests.""" +"""Router-level tests for pilots search/summary and JobID pseudo-parameter.""" from __future__ import annotations import pytest -from fastapi.testclient import TestClient - -from diracx.core.exceptions import InvalidQueryError -from diracx.core.models.pilot import PilotFieldsMapping, PilotStatus -from diracx.core.models.search import ( - ScalarSearchOperator, - ScalarSearchSpec, - SortDirection, - SortSpec, - VectorSearchOperator, - VectorSearchSpec, -) + +from diracx.core.models.pilot import PilotMetadata, PilotStatus +from diracx.db.sql import PilotAgentsDB +from diracx.logic.pilots.management import assign_jobs_to_pilot pytestmark = pytest.mark.enabled_dependencies( [ @@ -23,391 +15,177 @@ "DevelopmentSettings", "PilotAgentsDB", "PilotManagementAccessPolicy", + "JobDB", ] ) +MAIN_VO = "lhcb" +N = 20 + +PILOT_STATUSES = list(PilotStatus) + + @pytest.fixture def normal_test_client(client_factory): with client_factory.normal_user() as client: yield client -MAIN_VO = "lhcb" -N = 100 - -PILOT_REASONS = [ - "I was sick", - "I can't, I have a pony.", - "I was shopping", - "I was sleeping", -] - -PILOT_STATUSES = list(PilotStatus) - - @pytest.fixture async def populated_pilot_client(normal_test_client): + """Client with N pilots registered and metadata patched.""" pilot_stamps = [f"stamp_{i}" for i in range(1, N + 1)] - # -------------- Bulk insert -------------- - body = {"vo": MAIN_VO, "pilot_stamps": pilot_stamps} - r = normal_test_client.post( "/api/pilots/", - json=body, + json={"vo": MAIN_VO, "pilot_stamps": pilot_stamps}, ) - assert r.status_code == 200, r.json() - body = { - "pilot_stamps_to_fields_mapping": [ - PilotFieldsMapping( - PilotStamp=pilot_stamp, - BenchMark=i**2, - StatusReason=PILOT_REASONS[i % len(PILOT_REASONS)], - AccountingSent=True, - Status=PILOT_STATUSES[i % len(PILOT_STATUSES)], - CurrentJobID=i, - Queue=f"queue_{i}", - ).model_dump(exclude_unset=True) - for i, pilot_stamp in enumerate(pilot_stamps) - ] - } - - r = normal_test_client.patch("/api/pilots/metadata", json=body) + r = normal_test_client.patch( + "/api/pilots/metadata", + json={ + "pilot_metadata": [ + PilotMetadata( + PilotStamp=stamp, + BenchMark=float(i), + Status=PILOT_STATUSES[i % len(PILOT_STATUSES)], + Queue=f"queue_{i}", + ).model_dump(exclude_unset=True) + for i, stamp in enumerate(pilot_stamps) + ] + }, + ) + assert r.status_code == 204, r.json() + return normal_test_client - assert r.status_code == 204 - yield normal_test_client +def test_search_returns_pilots_from_own_vo(populated_pilot_client): + r = populated_pilot_client.post("/api/pilots/search", json={}) + assert r.status_code == 200, r.json() + pilots = r.json() + assert len(pilots) == N + assert all(p["VO"] == MAIN_VO for p in pilots) -async def test_pilot_summary(populated_pilot_client: TestClient): - # Group by StatusReason +def test_search_filter_by_status(populated_pilot_client): r = populated_pilot_client.post( - "/api/pilots/summary", + "/api/pilots/search", json={ - "grouping": ["StatusReason"], + "parameters": ["PilotStamp", "Status"], + "search": [ + { + "parameter": "Status", + "operator": "eq", + "value": PilotStatus.WAITING.value, + } + ], }, ) + assert r.status_code == 200, r.json() + pilots = r.json() + assert all(p["Status"] == PilotStatus.WAITING for p in pilots) - assert r.status_code == 200 - assert sum([el["count"] for el in r.json()]) == N - assert len(r.json()) == len(PILOT_REASONS) +def test_search_pagination_content_range(populated_pilot_client): + r = populated_pilot_client.post( + "/api/pilots/search?per_page=5&page=1", + json={}, + ) + assert r.status_code == 206 + assert "Content-Range" in r.headers + assert r.headers["Content-Range"] == f"pilots 0-4/{N}" + assert len(r.json()) == 5 + - # Group by CurrentJobID +def test_summary_groups_by_status(populated_pilot_client): r = populated_pilot_client.post( - "/api/pilots/summary", - json={ - "grouping": ["CurrentJobID"], - }, + "/api/pilots/summary", json={"grouping": ["Status"]} ) + assert r.status_code == 200, r.json() + totals = {row["Status"]: row["count"] for row in r.json()} + assert sum(totals.values()) == N - assert r.status_code == 200 - assert all(el["count"] == 1 for el in r.json()) - assert len(r.json()) == N +# --------------------------------------------------------------------------- +# Cross-table search: JobID pseudo-parameter on POST /api/pilots/search +# --------------------------------------------------------------------------- + + +async def _assign(client, stamp: str, job_ids: list[int]) -> None: + """Insert JobToPilotMapping rows directly via the DB dependency override.""" + db = client.app.dependency_overrides[PilotAgentsDB.transaction].args[0] + async with db: + await assign_jobs_to_pilot(pilot_db=db, pilot_stamp=stamp, job_ids=job_ids) + + +async def test_pilots_search_by_job_id_eq(populated_pilot_client): + """A `JobID` eq filter returns only the pilots that ran that job.""" + await _assign(populated_pilot_client, "stamp_1", [100]) + await _assign(populated_pilot_client, "stamp_2", [100]) + await _assign(populated_pilot_client, "stamp_3", [200]) - # Group by CurrentJobID where BenchMark < 10^2 r = populated_pilot_client.post( - "/api/pilots/summary", + "/api/pilots/search", json={ - "grouping": ["CurrentJobID"], - "search": [{"parameter": "BenchMark", "operator": "lt", "value": 10**2}], + "parameters": ["PilotStamp"], + "search": [{"parameter": "JobID", "operator": "eq", "value": 100}], }, ) - assert r.status_code == 200, r.json() + stamps = sorted(p["PilotStamp"] for p in r.json()) + assert stamps == ["stamp_1", "stamp_2"] - assert all(el["count"] == 1 for el in r.json()) - assert len(r.json()) == 10 +async def test_pilots_search_by_job_id_in(populated_pilot_client): + """An `in` filter over several job IDs returns the union of their pilots.""" + await _assign(populated_pilot_client, "stamp_4", [300]) + await _assign(populated_pilot_client, "stamp_5", [301]) -@pytest.fixture -async def search(populated_pilot_client): - async def _search( - parameters, conditions, sorts, distinct=False, page=1, per_page=100 - ): - body = { - "parameters": parameters, - "search": conditions, - "sort": sorts, - "distinct": distinct, - } - - params = {"per_page": per_page, "page": page} - - r = populated_pilot_client.post("/api/pilots/search", json=body, params=params) - - if r.status_code in (400, 422): - # If we have a status_code 400/422, that means that the query failed - raise InvalidQueryError() - - return r.json(), r.headers - - return _search - - -async def test_search_parameters(search): - """Test that we can search specific parameters for pilots.""" - # Search a specific parameter: PilotID - result, headers = await search(["PilotID"], [], []) - assert len(result) == N - assert result - for r in result: - assert r.keys() == {"PilotID"} - assert "Content-Range" not in headers - - # Search a specific parameter: Status - result, headers = await search(["Status"], [], []) - assert len(result) == N - assert result - for r in result: - assert r.keys() == {"Status"} - assert "Content-Range" not in headers - - # Search for multiple parameters: PilotID, Status - result, headers = await search(["PilotID", "Status"], [], []) - assert len(result) == N - assert result - for r in result: - assert r.keys() == {"PilotID", "Status"} - assert "Content-Range" not in headers - - # Search for a specific parameter but use distinct: Status - result, headers = await search(["Status"], [], [], distinct=True) - assert len(result) == len(PILOT_STATUSES) - assert result - assert "Content-Range" not in headers - - # Search for a non-existent parameter: Dummy - with pytest.raises(InvalidQueryError): - result, headers = await search(["Dummy"], [], []) - - -async def test_search_conditions(search): - """Test that we can search for specific pilots.""" - # Search a specific scalar condition: PilotID eq 3 - condition = ScalarSearchSpec( - parameter="PilotID", operator=ScalarSearchOperator.EQUAL, value=3 - ) - result, headers = await search([], [condition], []) - assert len(result) == 1 - assert result - assert len(result) == 1 - assert result[0]["PilotID"] == 3 - assert "Content-Range" not in headers - - # Search a specific scalar condition: PilotID lt 3 - condition = ScalarSearchSpec( - parameter="PilotID", operator=ScalarSearchOperator.LESS_THAN, value=3 - ) - result, headers = await search([], [condition], []) - assert len(result) == 2 - assert result - assert len(result) == 2 - assert result[0]["PilotID"] == 1 - assert result[1]["PilotID"] == 2 - assert "Content-Range" not in headers - - # Search a specific scalar condition: PilotID neq 3 - condition = ScalarSearchSpec( - parameter="PilotID", operator=ScalarSearchOperator.NOT_EQUAL, value=3 - ) - result, headers = await search([], [condition], []) - assert len(result) == 99 - assert result - assert len(result) == 99 - assert all(r["PilotID"] != 3 for r in result) - assert "Content-Range" not in headers - - # Search a specific scalar condition: PilotID eq 5873 (does not exist) - condition = ScalarSearchSpec( - parameter="PilotID", operator=ScalarSearchOperator.EQUAL, value=5873 + r = populated_pilot_client.post( + "/api/pilots/search", + json={ + "parameters": ["PilotStamp"], + "search": [ + { + "parameter": "JobID", + "operator": "in", + "values": [300, 301], + } + ], + }, ) - result, headers = await search([], [condition], []) - assert not result - assert "Content-Range" not in headers + assert r.status_code == 200, r.json() + stamps = sorted(p["PilotStamp"] for p in r.json()) + assert stamps == ["stamp_4", "stamp_5"] - # Search a specific vector condition: PilotID in 1,2,3 - condition = VectorSearchSpec( - parameter="PilotID", operator=VectorSearchOperator.IN, values=[1, 2, 3] - ) - result, headers = await search([], [condition], []) - assert len(result) == 3 - assert result - assert len(result) == 3 - assert all(r["PilotID"] in [1, 2, 3] for r in result) - assert "Content-Range" not in headers - - # Search a specific vector condition: PilotID in 1,2,5873 (one of them does not exist) - condition = VectorSearchSpec( - parameter="PilotID", operator=VectorSearchOperator.IN, values=[1, 2, 5873] - ) - result, headers = await search([], [condition], []) - assert len(result) == 2 - assert result - assert len(result) == 2 - assert all(r["PilotID"] in [1, 2] for r in result) - assert "Content-Range" not in headers - - # Search a specific vector condition: PilotID not in 1,2,3 - condition = VectorSearchSpec( - parameter="PilotID", operator=VectorSearchOperator.NOT_IN, values=[1, 2, 3] - ) - result, headers = await search([], [condition], []) - assert len(result) == 97 - assert result - assert len(result) == 97 - assert all(r["PilotID"] not in [1, 2, 3] for r in result) - assert "Content-Range" not in headers - - # Search a specific vector condition: PilotID not in 1,2,5873 (one of them does not exist) - condition = VectorSearchSpec( - parameter="PilotID", - operator=VectorSearchOperator.NOT_IN, - values=[1, 2, 5873], - ) - result, headers = await search([], [condition], []) - assert len(result) == 98 - assert result - assert len(result) == 98 - assert all(r["PilotID"] not in [1, 2] for r in result) - assert "Content-Range" not in headers - - # Search for multiple conditions based on different parameters: PilotID eq 70, PilotID in 4,5,6 - condition1 = ScalarSearchSpec( - parameter="PilotStamp", operator=ScalarSearchOperator.EQUAL, value="stamp_5" - ) - condition2 = VectorSearchSpec( - parameter="PilotID", operator=VectorSearchOperator.IN, values=[4, 5, 6] + +def test_pilots_search_by_unknown_job_id_returns_empty(populated_pilot_client): + r = populated_pilot_client.post( + "/api/pilots/search", + json={"search": [{"parameter": "JobID", "operator": "eq", "value": 999999}]}, ) - result, headers = await search([], [condition1, condition2], []) + assert r.status_code == 200 + assert r.json() == [] - assert result - assert len(result) == 1 - assert result[0]["PilotID"] == 5 - assert result[0]["PilotStamp"] == "stamp_5" - assert "Content-Range" not in headers - # Search for multiple conditions based on the same parameter: PilotID eq 70, PilotID in 4,5,6 - condition1 = ScalarSearchSpec( - parameter="PilotID", operator=ScalarSearchOperator.EQUAL, value=70 +def test_pilots_search_job_id_unsupported_operator_raises(populated_pilot_client): + r = populated_pilot_client.post( + "/api/pilots/search", + json={"search": [{"parameter": "JobID", "operator": "neq", "value": 1}]}, ) - condition2 = VectorSearchSpec( - parameter="PilotID", operator=VectorSearchOperator.IN, values=[4, 5, 6] + assert r.status_code in (400, 422), r.json() + + +def test_pilots_search_combining_job_id_and_pilot_id_raises(populated_pilot_client): + r = populated_pilot_client.post( + "/api/pilots/search", + json={ + "search": [ + {"parameter": "JobID", "operator": "eq", "value": 1}, + {"parameter": "PilotID", "operator": "eq", "value": 1}, + ] + }, ) - result, headers = await search([], [condition1, condition2], []) - assert len(result) == 0 - assert not result - assert "Content-Range" not in headers - - -async def test_search_sorts(search): - """Test that we can search for pilots and sort the results.""" - # Search and sort by PilotID in ascending order - sort = SortSpec(parameter="PilotID", direction=SortDirection.ASC) - result, headers = await search([], [], [sort]) - assert len(result) == N - assert result - for i, r in enumerate(result): - assert r["PilotID"] == i + 1 - assert "Content-Range" not in headers - - # Search and sort by PilotID in descending order - sort = SortSpec(parameter="PilotID", direction=SortDirection.DESC) - result, headers = await search([], [], [sort]) - assert len(result) == N - assert result - for i, r in enumerate(result): - assert r["PilotID"] == N - i - assert "Content-Range" not in headers - - # Search and sort by PilotStamp in ascending order - sort = SortSpec(parameter="PilotStamp", direction=SortDirection.ASC) - result, headers = await search([], [], [sort]) - assert len(result) == N - assert result - # Assert that stamp_10 is before stamp_2 because of the lexicographical order - assert result[2]["PilotStamp"] == "stamp_100" - assert result[12]["PilotStamp"] == "stamp_2" - assert "Content-Range" not in headers - - # Search and sort by PilotStamp in descending order - sort = SortSpec(parameter="PilotStamp", direction=SortDirection.DESC) - result, headers = await search([], [], [sort]) - assert len(result) == N - assert result - # Assert that stamp_10 is before stamp_2 because of the lexicographical order - assert result[97]["PilotStamp"] == "stamp_100" - assert result[87]["PilotStamp"] == "stamp_2" - assert "Content-Range" not in headers - - # Search and sort by PilotStamp in ascending order and PilotID in descending order - sort1 = SortSpec(parameter="PilotStamp", direction=SortDirection.ASC) - sort2 = SortSpec(parameter="PilotID", direction=SortDirection.DESC) - result, headers = await search([], [], [sort1, sort2]) - assert len(result) == N - assert result - assert result[0]["PilotStamp"] == "stamp_1" - assert result[0]["PilotID"] == 1 - assert result[99]["PilotStamp"] == "stamp_99" - assert result[99]["PilotID"] == 99 - assert "Content-Range" not in headers - - -async def test_search_pagination(search): - """Test that we can search for pilots.""" - # Search for the first 10 pilots - result, headers = await search([], [], [], per_page=10, page=1) - assert "Content-Range" in headers - # Because Content-Range = f"pilots {first_idx}-{last_idx}/{total}" - total = int(headers["Content-Range"].split("/")[1]) - assert total == N - assert result - assert len(result) == 10 - assert result[0]["PilotID"] == 1 - - # Search for the second 10 pilots - result, headers = await search([], [], [], per_page=10, page=2) - assert "Content-Range" in headers - # Because Content-Range = f"pilots {first_idx}-{last_idx}/{total}" - total = int(headers["Content-Range"].split("/")[1]) - assert total == N - assert result - assert len(result) == 10 - assert result[0]["PilotID"] == 11 - - # Search for the last 10 pilots - result, headers = await search([], [], [], per_page=10, page=10) - assert "Content-Range" in headers - # Because Content-Range = f"pilots {first_idx}-{last_idx}/{total}" - total = int(headers["Content-Range"].split("/")[1]) - assert result - assert len(result) == 10 - assert result[0]["PilotID"] == 91 - - # Search for the second 50 pilots - result, headers = await search([], [], [], per_page=50, page=2) - assert "Content-Range" in headers - # Because Content-Range = f"pilots {first_idx}-{last_idx}/{total}" - total = int(headers["Content-Range"].split("/")[1]) - assert result - assert len(result) == 50 - assert result[0]["PilotID"] == 51 - - # Invalid page number - result, headers = await search([], [], [], per_page=10, page=11) - assert "Content-Range" in headers - # Because Content-Range = f"pilots {first_idx}-{last_idx}/{total}" - total = int(headers["Content-Range"].split("/")[1]) - assert not result - - # Invalid page number - with pytest.raises(InvalidQueryError): - result = await search([], [], [], per_page=10, page=0) - - # Invalid per_page number - with pytest.raises(InvalidQueryError): - result = await search([], [], [], per_page=0, page=1) + assert r.status_code in (400, 422), r.json() diff --git a/docs/admin/explanations/pilots.md b/docs/admin/explanations/pilots.md new file mode 100644 index 000000000..4eaf82b05 --- /dev/null +++ b/docs/admin/explanations/pilots.md @@ -0,0 +1,11 @@ +# Pilots from the operator's perspective + +This page is for operators running a DiracX instance. For the developer view (identity model, lifecycle, cross-table search, extension points) see the [developer explanation](../../dev/explanations/pilots.md). + +## VO scoping + +Pilots are partitioned by VO. By default a user only sees and acts on pilots belonging to their own VO. Users holding the +`SERVICE_ADMINISTRATOR` security property bypass that filter and can read pilots across every VO through the same search and summary endpoints. + +Management actions (register, delete, patch metadata) are restricted to `SERVICE_ADMINISTRATOR`. Legacy X.509 pilot identities (`GENERIC_PILOT` property) may be permitted to self-register or self-modify on a per-route basis; those routes opt in via `allow_legacy_pilots=True` in the access policy and cap each call to a single pilot stamp as a containment measure against stolen credentials. +See [authentication with DiracX](auth-with-diracx.md) for the full pilot auth story (X.509, DiracX pilot tokens). diff --git a/docs/dev/explanations/pilots.md b/docs/dev/explanations/pilots.md index 544e12db5..6eb884acf 100644 --- a/docs/dev/explanations/pilots.md +++ b/docs/dev/explanations/pilots.md @@ -1,20 +1,58 @@ -## Presentation +# Pilots -Pilots are a piece of software that is running on *worker nodes*. There are two types of pilots: "DIRAC pilots", and "DiracX pilots". The first type corresponds to pilots with proxies, sent by DIRAC; and the second type corresponds to pilots with secrets. Both kinds will eventually interact with DiracX using tokens (DIRAC pilots by exchanging their proxies for tokens, DiracX by exchanging their secrets for tokens). +## What is a pilot -## Management +A pilot is a small piece of software that runs on a *worker node* and pulls user payloads (jobs). The same pilot binary is equipped to talk to both DIRAC and DiracX during the migration, and it supports two authentication modes: -Their management is adapted in DiracX, and each feature has its own route in DiracX. We will split the `/pilots` route into two parts: +- **X.509 proxy** (legacy): the pilot presents a proxy and, in DiracX, exchanges it for a DiracX token. Callers authenticated this way carry the `GENERIC_PILOT` property and are handled by the "legacy pilot" code paths in the access policy. +- **Pre-issued secret**: the pilot is provisioned with a secret that it exchanges for a DiracX token. Pilots authenticated this way are identified by their unique *stamp* rather than by a set of security properties. -1. `/api/pilots/*` to allow administrators and users to access and modify pilots -2. `/api/pilots/internal/*` is allocated for pilots resources: only DiracX pilots will have access to these resources +## Identity model -Each part has its own security policy: we want to prevent pilots to access users resources and vice-versa. To differentiate DIRAC pilots from users, we can get their token and compare their properties: `GENERIC_PILOT` is the property that defines a pilot. For DiracX pilots, we can differentiate them by looking at the token structure: they don't have properties, but a "stamp" (their identifier). +Three identifiers appear throughout the code and are easy to confuse: -## Endpoints +- `PilotStamp`: immutable string chosen by the pilot factory. Primary user-facing key; never changes for the lifetime of a pilot. +- `PilotID`: auto-incrementing database primary key. Not meaningful outside the DB layer; never exposed on the HTTP surface as an identity. +- `PilotJobReference`: the CE job reference (batch-system identifier) + that submitted the pilot process. Defaults to the stamp when not known. -We ordered our endpoints like so: +## Lifecycle -1. Creation: `POST /api/pilots/` -2. Deletion: `DELETE /api/pilots/` -3. Modification: `PATCH /api/pilots/metadata` +```mermaid +stateDiagram-v2 + [*] --> Submitted + Submitted --> Waiting + Waiting --> Running + Running --> Done + Running --> Failed + Submitted --> Aborted + Waiting --> Aborted + Running --> Aborted + Running --> Stalled + Stalled --> Running + Stalled --> Failed + Done --> [*] + Failed --> [*] + Aborted --> [*] + [*] --> Unknown + Unknown --> [*] +``` + +## Relationship to jobs + +A pilot can execute zero or more jobs over its lifetime. The association is tracked in the `JobToPilotMapping` table and is append-only: once a job has run on a pilot, the link is preserved until the pilot row is deleted. + +Both directions of the lookup are exposed as *pseudo-parameters* on the respective search endpoints. This keeps every pilot and job attribute addressable through a single `POST /search` per resource type, matching the UI's one-search-bar-per-resource mental model. The pattern mirrors the existing `LoggingInfo` pseudo-parameter on `POST /api/jobs/search`: the filter is intercepted in the logic layer, resolved against `JobToPilotMapping`, and rewritten into a normal vector filter before hitting the DB. + +- `POST /api/jobs/search` accepts a `PilotStamp` filter, resolved to a `JobID` filter via `JobToPilotMapping`. +- `POST /api/pilots/search` accepts a `JobID` filter, resolved to a `PilotID` filter. + +Concrete request bodies for both are provided as OpenAPI examples on the respective search routes; open the Swagger UI at `/api/docs` to see them. + +Only `eq` and `in` operators are supported on the pseudo-parameter; other operators (`neq`, `not in`, `lt`, ...) are refused with `InvalidQueryError` because their semantics across the join are ambiguous. Combining a `PilotStamp` filter with a `JobID` filter in the same request body is likewise refused; clients that want the intersection should compute it themselves. + +## VO scoping and authorization + +Pilots are partitioned by VO. By default a normal user sees and acts on pilots belonging to their own VO only. `SERVICE_ADMINISTRATOR` can read pilots across VOs via `/search` and `/summary`. + +Management actions (register, delete, patch metadata) require `SERVICE_ADMINISTRATOR`. Legacy X.509 pilot identities may be permitted to self-register or self-modify; those paths opt in via `allow_legacy_pilots=True` in the access policy and limit each call to a single pilot stamp as a containment measure against stolen credentials. diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/_client.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/_client.py index b19d47c68..a3640a1cf 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/_client.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/_client.py @@ -21,11 +21,12 @@ JobsOperations, LollygagOperations, MyOperations, + PilotsOperations, WellKnownOperations, ) -class Dirac: # pylint: disable=client-accepts-api-version-keyword +class Dirac: # pylint: disable=client-accepts-api-version-keyword,too-many-instance-attributes """Dirac. :ivar well_known: WellKnownOperations operations @@ -40,6 +41,8 @@ class Dirac: # pylint: disable=client-accepts-api-version-keyword :vartype lollygag: _generated.operations.LollygagOperations :ivar my: MyOperations operations :vartype my: _generated.operations.MyOperations + :ivar pilots: PilotsOperations operations + :vartype pilots: _generated.operations.PilotsOperations :keyword endpoint: Service URL. Required. Default value is "". :paramtype endpoint: str """ @@ -78,6 +81,7 @@ def __init__( # pylint: disable=missing-client-constructor-parameter-credential self.jobs = JobsOperations(self._client, self._config, self._serialize, self._deserialize) self.lollygag = LollygagOperations(self._client, self._config, self._serialize, self._deserialize) self.my = MyOperations(self._client, self._config, self._serialize, self._deserialize) + self.pilots = PilotsOperations(self._client, self._config, self._serialize, self._deserialize) def send_request(self, request: HttpRequest, *, stream: bool = False, **kwargs: Any) -> HttpResponse: """Runs the network request through the client's chained policies. diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/_client.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/_client.py index 32b9dad3a..8a4597b95 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/_client.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/_client.py @@ -21,11 +21,12 @@ JobsOperations, LollygagOperations, MyOperations, + PilotsOperations, WellKnownOperations, ) -class Dirac: # pylint: disable=client-accepts-api-version-keyword +class Dirac: # pylint: disable=client-accepts-api-version-keyword,too-many-instance-attributes """Dirac. :ivar well_known: WellKnownOperations operations @@ -40,6 +41,8 @@ class Dirac: # pylint: disable=client-accepts-api-version-keyword :vartype lollygag: _generated.aio.operations.LollygagOperations :ivar my: MyOperations operations :vartype my: _generated.aio.operations.MyOperations + :ivar pilots: PilotsOperations operations + :vartype pilots: _generated.aio.operations.PilotsOperations :keyword endpoint: Service URL. Required. Default value is "". :paramtype endpoint: str """ @@ -78,6 +81,7 @@ def __init__( # pylint: disable=missing-client-constructor-parameter-credential self.jobs = JobsOperations(self._client, self._config, self._serialize, self._deserialize) self.lollygag = LollygagOperations(self._client, self._config, self._serialize, self._deserialize) self.my = MyOperations(self._client, self._config, self._serialize, self._deserialize) + self.pilots = PilotsOperations(self._client, self._config, self._serialize, self._deserialize) def send_request( self, request: HttpRequest, *, stream: bool = False, **kwargs: Any diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/operations/__init__.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/operations/__init__.py index 5cfdf7253..6becb82b3 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/operations/__init__.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/operations/__init__.py @@ -16,6 +16,7 @@ from ._operations import JobsOperations # type: ignore from ._operations import LollygagOperations # type: ignore from ._operations import MyOperations # type: ignore +from ._operations import PilotsOperations # type: ignore from ._patch import __all__ as _patch_all from ._patch import * @@ -28,6 +29,7 @@ "JobsOperations", "LollygagOperations", "MyOperations", + "PilotsOperations", ] __all__.extend([p for p in _patch_all if p not in __all__]) # pyright: ignore _patch_sdk() diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/operations/_operations.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/operations/_operations.py index a2e0565c5..786045e19 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/operations/_operations.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/aio/operations/_operations.py @@ -56,6 +56,11 @@ build_lollygag_insert_owner_object_request, build_my_pilots_get_pilot_summary_request, build_my_pilots_submit_pilot_request, + build_pilots_delete_pilots_request, + build_pilots_register_pilots_request, + build_pilots_search_request, + build_pilots_summary_request, + build_pilots_update_pilot_metadata_request, build_well_known_get_installation_metadata_request, build_well_known_get_jwks_request, build_well_known_get_openid_configuration_request, @@ -1958,6 +1963,12 @@ async def search( By default, the search will return all jobs the user has access to, and all the fields of the job will be returned. + A ``PilotStamp`` pseudo-parameter is also accepted in the ``search`` + filter list (operators ``eq`` / ``in`` only): it is transparently + resolved through ``JobToPilotMapping`` into a ``JobID`` filter, + allowing callers to ask "jobs run by this pilot" through the same + endpoint. + :param body: Default value is None. :type body: ~_generated.models.SearchParams :keyword page: Default value is 1. @@ -1998,6 +2009,12 @@ async def search( By default, the search will return all jobs the user has access to, and all the fields of the job will be returned. + A ``PilotStamp`` pseudo-parameter is also accepted in the ``search`` + filter list (operators ``eq`` / ``in`` only): it is transparently + resolved through ``JobToPilotMapping`` into a ``JobID`` filter, + allowing callers to ask "jobs run by this pilot" through the same + endpoint. + :param body: Default value is None. :type body: IO[bytes] :keyword page: Default value is 1. @@ -2037,6 +2054,12 @@ async def search( By default, the search will return all jobs the user has access to, and all the fields of the job will be returned. + A ``PilotStamp`` pseudo-parameter is also accepted in the ``search`` + filter list (operators ``eq`` / ``in`` only): it is transparently + resolved through ``JobToPilotMapping`` into a ``JobID`` filter, + allowing callers to ask "jobs run by this pilot" through the same + endpoint. + :param body: Is either a SearchParams type or a IO[bytes] type. Default value is None. :type body: ~_generated.models.SearchParams or IO[bytes] :keyword page: Default value is 1. @@ -2605,3 +2628,546 @@ async def pilots_get_pilot_summary(self, **kwargs: Any) -> dict[str, int]: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized # type: ignore + + +class PilotsOperations: + """ + .. warning:: + **DO NOT** instantiate this class directly. + + Instead, you should access the following operations through + :class:`~_generated.aio.Dirac`'s + :attr:`pilots` attribute. + """ + + models = _models + + def __init__(self, *args, **kwargs) -> None: + input_args = list(args) + self._client: AsyncPipelineClient = input_args.pop(0) if input_args else kwargs.pop("client") + self._config: DiracConfiguration = input_args.pop(0) if input_args else kwargs.pop("config") + self._serialize: Serializer = input_args.pop(0) if input_args else kwargs.pop("serializer") + self._deserialize: Deserializer = input_args.pop(0) if input_args else kwargs.pop("deserializer") + + @overload + async def register_pilots( + self, body: _models.BodyPilotsRegisterPilots, *, content_type: str = "application/json", **kwargs: Any + ) -> Any: + """Register Pilots. + + Register a batch of pilots with their references. + + If any stamp already exists, the whole batch is rejected with a 409. + + :param body: Required. + :type body: ~_generated.models.BodyPilotsRegisterPilots + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + async def register_pilots(self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any) -> Any: + """Register Pilots. + + Register a batch of pilots with their references. + + If any stamp already exists, the whole batch is rejected with a 409. + + :param body: Required. + :type body: IO[bytes] + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace_async + async def register_pilots(self, body: Union[_models.BodyPilotsRegisterPilots, IO[bytes]], **kwargs: Any) -> Any: + """Register Pilots. + + Register a batch of pilots with their references. + + If any stamp already exists, the whole batch is rejected with a 409. + + :param body: Is either a BodyPilotsRegisterPilots type or a IO[bytes] type. Required. + :type body: ~_generated.models.BodyPilotsRegisterPilots or IO[bytes] + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[Any] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _json = self._serialize.body(body, "BodyPilotsRegisterPilots") + + _request = build_pilots_register_pilots_request( + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("object", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore + + @distributed_trace_async + async def delete_pilots(self, *, pilot_stamps: list[str], **kwargs: Any) -> None: + """Delete Pilots. + + Delete pilots by stamp. + + Deletes the pilot rows as well as their logs and job associations. + + Age-based retention cleanup is deliberately *not* exposed here: it is + handled by the maintenance task worker. See + ``diracx.logic.pilots.management.delete_pilots``. + + :keyword pilot_stamps: Stamps of the pilots to delete. Required. + :paramtype pilot_stamps: list[str] + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = kwargs.pop("headers", {}) or {} + _params = kwargs.pop("params", {}) or {} + + cls: ClsType[None] = kwargs.pop("cls", None) + + _request = build_pilots_delete_pilots_request( + pilot_stamps=pilot_stamps, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [204]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + if cls: + return cls(pipeline_response, None, {}) # type: ignore + + @overload + async def update_pilot_metadata( + self, body: _models.BodyPilotsUpdatePilotMetadata, *, content_type: str = "application/json", **kwargs: Any + ) -> None: + """Update Pilot Metadata. + + Update pilot metadata (status, benchmark, etc.). + + Only fields defined in ``PilotMetadata`` are mutable. ``PilotStamp`` + identifies the row and cannot be changed. + + :param body: Required. + :type body: ~_generated.models.BodyPilotsUpdatePilotMetadata + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + async def update_pilot_metadata( + self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any + ) -> None: + """Update Pilot Metadata. + + Update pilot metadata (status, benchmark, etc.). + + Only fields defined in ``PilotMetadata`` are mutable. ``PilotStamp`` + identifies the row and cannot be changed. + + :param body: Required. + :type body: IO[bytes] + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace_async + async def update_pilot_metadata( + self, body: Union[_models.BodyPilotsUpdatePilotMetadata, IO[bytes]], **kwargs: Any + ) -> None: + """Update Pilot Metadata. + + Update pilot metadata (status, benchmark, etc.). + + Only fields defined in ``PilotMetadata`` are mutable. ``PilotStamp`` + identifies the row and cannot be changed. + + :param body: Is either a BodyPilotsUpdatePilotMetadata type or a IO[bytes] type. Required. + :type body: ~_generated.models.BodyPilotsUpdatePilotMetadata or IO[bytes] + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[None] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _json = self._serialize.body(body, "BodyPilotsUpdatePilotMetadata") + + _request = build_pilots_update_pilot_metadata_request( + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [204]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + if cls: + return cls(pipeline_response, None, {}) # type: ignore + + @overload + async def search( + self, + body: Optional[_models.SearchParams] = None, + *, + page: int = 1, + per_page: int = 100, + content_type: str = "application/json", + **kwargs: Any + ) -> list[dict[str, Any]]: + """Search. + + Retrieve information about pilots. + + Normal users see only their own VO's pilots. Service administrators see + pilots from all VOs. + + A ``JobID`` pseudo-parameter is also accepted in the ``search`` filter + list (operators ``eq`` / ``in`` only): it is transparently resolved + through ``JobToPilotMapping`` into a ``PilotID`` filter, allowing + callers to ask "pilots that ran this job" through the same endpoint. + + :param body: Default value is None. + :type body: ~_generated.models.SearchParams + :keyword page: Default value is 1. + :paramtype page: int + :keyword per_page: Default value is 100. + :paramtype per_page: int + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: list of dict mapping str to any + :rtype: list[dict[str, any]] + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + async def search( + self, + body: Optional[IO[bytes]] = None, + *, + page: int = 1, + per_page: int = 100, + content_type: str = "application/json", + **kwargs: Any + ) -> list[dict[str, Any]]: + """Search. + + Retrieve information about pilots. + + Normal users see only their own VO's pilots. Service administrators see + pilots from all VOs. + + A ``JobID`` pseudo-parameter is also accepted in the ``search`` filter + list (operators ``eq`` / ``in`` only): it is transparently resolved + through ``JobToPilotMapping`` into a ``PilotID`` filter, allowing + callers to ask "pilots that ran this job" through the same endpoint. + + :param body: Default value is None. + :type body: IO[bytes] + :keyword page: Default value is 1. + :paramtype page: int + :keyword per_page: Default value is 100. + :paramtype per_page: int + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: list of dict mapping str to any + :rtype: list[dict[str, any]] + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace_async + async def search( + self, + body: Optional[Union[_models.SearchParams, IO[bytes]]] = None, + *, + page: int = 1, + per_page: int = 100, + **kwargs: Any + ) -> list[dict[str, Any]]: + """Search. + + Retrieve information about pilots. + + Normal users see only their own VO's pilots. Service administrators see + pilots from all VOs. + + A ``JobID`` pseudo-parameter is also accepted in the ``search`` filter + list (operators ``eq`` / ``in`` only): it is transparently resolved + through ``JobToPilotMapping`` into a ``PilotID`` filter, allowing + callers to ask "pilots that ran this job" through the same endpoint. + + :param body: Is either a SearchParams type or a IO[bytes] type. Default value is None. + :type body: ~_generated.models.SearchParams or IO[bytes] + :keyword page: Default value is 1. + :paramtype page: int + :keyword per_page: Default value is 100. + :paramtype per_page: int + :return: list of dict mapping str to any + :rtype: list[dict[str, any]] + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + content_type = content_type if body else None + cls: ClsType[list[dict[str, Any]]] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" if body else None + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + if body is not None: + _json = self._serialize.body(body, "SearchParams") + else: + _json = None + + _request = build_pilots_search_request( + page=page, + per_page=per_page, + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200, 206]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + response_headers = {} + if response.status_code == 206: + response_headers["Content-Range"] = self._deserialize("str", response.headers.get("Content-Range")) + + deserialized = self._deserialize("[{object}]", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, response_headers) # type: ignore + + return deserialized # type: ignore + + @overload + async def summary( + self, body: _models.SummaryParams, *, content_type: str = "application/json", **kwargs: Any + ) -> Any: + """Summary. + + Aggregate pilot counts suitable for plotting. + + Normal users see only their own VO's pilots. Service administrators see + pilots from all VOs. + + :param body: Required. + :type body: ~_generated.models.SummaryParams + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + async def summary(self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any) -> Any: + """Summary. + + Aggregate pilot counts suitable for plotting. + + Normal users see only their own VO's pilots. Service administrators see + pilots from all VOs. + + :param body: Required. + :type body: IO[bytes] + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace_async + async def summary(self, body: Union[_models.SummaryParams, IO[bytes]], **kwargs: Any) -> Any: + """Summary. + + Aggregate pilot counts suitable for plotting. + + Normal users see only their own VO's pilots. Service administrators see + pilots from all VOs. + + :param body: Is either a SummaryParams type or a IO[bytes] type. Required. + :type body: ~_generated.models.SummaryParams or IO[bytes] + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[Any] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _json = self._serialize.body(body, "SummaryParams") + + _request = build_pilots_summary_request( + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = await self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("object", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/__init__.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/__init__.py index 25803aefc..7d0d5faef 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/__init__.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/__init__.py @@ -16,8 +16,8 @@ BodyAuthRevokeRefreshTokenByRefreshToken, BodyJobsRescheduleJobs, BodyJobsUnassignBulkJobsSandboxes, - BodyPilotsAddPilotStamps, - BodyPilotsUpdatePilotFields, + BodyPilotsRegisterPilots, + BodyPilotsUpdatePilotMetadata, ExtendedMetadata, GroupInfo, HTTPValidationError, @@ -28,7 +28,7 @@ JobMetaData, JobStatusUpdate, OpenIDConfiguration, - PilotFieldsMapping, + PilotMetadata, SandboxDownloadResponse, SandboxInfo, SandboxUploadResponse, @@ -67,8 +67,8 @@ "BodyAuthRevokeRefreshTokenByRefreshToken", "BodyJobsRescheduleJobs", "BodyJobsUnassignBulkJobsSandboxes", - "BodyPilotsAddPilotStamps", - "BodyPilotsUpdatePilotFields", + "BodyPilotsRegisterPilots", + "BodyPilotsUpdatePilotMetadata", "ExtendedMetadata", "GroupInfo", "HTTPValidationError", @@ -79,7 +79,7 @@ "JobMetaData", "JobStatusUpdate", "OpenIDConfiguration", - "PilotFieldsMapping", + "PilotMetadata", "SandboxDownloadResponse", "SandboxInfo", "SandboxUploadResponse", diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/_models.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/_models.py index 2e88298bb..8b1c0c229 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/_models.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/models/_models.py @@ -184,12 +184,12 @@ def __init__(self, *, job_ids: list[int], **kwargs: Any) -> None: self.job_ids = job_ids -class BodyPilotsAddPilotStamps(_serialization.Model): - """Body_pilots_add_pilot_stamps. +class BodyPilotsRegisterPilots(_serialization.Model): + """Body_pilots_register_pilots. All required parameters must be populated in order to send to server. - :ivar pilot_stamps: List of the pilot stamps we want to add to the db. Required. + :ivar pilot_stamps: Stamps of the pilots to create. Required. :vartype pilot_stamps: list[str] :ivar vo: Pilot virtual organization. Required. :vartype vo: str @@ -201,8 +201,8 @@ class BodyPilotsAddPilotStamps(_serialization.Model): :vartype destination_site: str :ivar pilot_references: Association of a pilot reference with a pilot stamp. :vartype pilot_references: dict[str, str] - :ivar pilot_status: Status of the pilots. Known values are: "Submitted", "Waiting", "Running", - "Done", "Failed", "Deleted", "Aborted", and "Unknown". + :ivar pilot_status: Initial status of the pilots. Known values are: "Submitted", "Waiting", + "Running", "Done", "Failed", "Deleted", "Aborted", and "Unknown". :vartype pilot_status: str or ~_generated.models.PilotStatus """ @@ -234,7 +234,7 @@ def __init__( **kwargs: Any ) -> None: """ - :keyword pilot_stamps: List of the pilot stamps we want to add to the db. Required. + :keyword pilot_stamps: Stamps of the pilots to create. Required. :paramtype pilot_stamps: list[str] :keyword vo: Pilot virtual organization. Required. :paramtype vo: str @@ -246,7 +246,7 @@ def __init__( :paramtype destination_site: str :keyword pilot_references: Association of a pilot reference with a pilot stamp. :paramtype pilot_references: dict[str, str] - :keyword pilot_status: Status of the pilots. Known values are: "Submitted", "Waiting", + :keyword pilot_status: Initial status of the pilots. Known values are: "Submitted", "Waiting", "Running", "Done", "Failed", "Deleted", "Aborted", and "Unknown". :paramtype pilot_status: str or ~_generated.models.PilotStatus """ @@ -260,31 +260,30 @@ def __init__( self.pilot_status = pilot_status -class BodyPilotsUpdatePilotFields(_serialization.Model): - """Body_pilots_update_pilot_fields. +class BodyPilotsUpdatePilotMetadata(_serialization.Model): + """Body_pilots_update_pilot_metadata. All required parameters must be populated in order to send to server. - :ivar pilot_stamps_to_fields_mapping: (pilot_stamp, pilot_fields) mapping to change. Required. - :vartype pilot_stamps_to_fields_mapping: list[~_generated.models.PilotFieldsMapping] + :ivar pilot_metadata: Pilot metadata mappings to apply. Required. + :vartype pilot_metadata: list[~_generated.models.PilotMetadata] """ _validation = { - "pilot_stamps_to_fields_mapping": {"required": True}, + "pilot_metadata": {"required": True}, } _attribute_map = { - "pilot_stamps_to_fields_mapping": {"key": "pilot_stamps_to_fields_mapping", "type": "[PilotFieldsMapping]"}, + "pilot_metadata": {"key": "pilot_metadata", "type": "[PilotMetadata]"}, } - def __init__(self, *, pilot_stamps_to_fields_mapping: list["_models.PilotFieldsMapping"], **kwargs: Any) -> None: + def __init__(self, *, pilot_metadata: list["_models.PilotMetadata"], **kwargs: Any) -> None: """ - :keyword pilot_stamps_to_fields_mapping: (pilot_stamp, pilot_fields) mapping to change. - Required. - :paramtype pilot_stamps_to_fields_mapping: list[~_generated.models.PilotFieldsMapping] + :keyword pilot_metadata: Pilot metadata mappings to apply. Required. + :paramtype pilot_metadata: list[~_generated.models.PilotMetadata] """ super().__init__(**kwargs) - self.pilot_stamps_to_fields_mapping = pilot_stamps_to_fields_mapping + self.pilot_metadata = pilot_metadata class ExtendedMetadata(_serialization.Model): @@ -1053,31 +1052,34 @@ def __init__( self.code_challenge_methods_supported = code_challenge_methods_supported -class PilotFieldsMapping(_serialization.Model): - """All the fields that a user can modify on a Pilot (except PilotStamp). +class PilotMetadata(_serialization.Model): + """Mutable metadata attached to a pilot. + + ``PilotStamp`` identifies the pilot and cannot be changed. Every other + field is optional; when absent it is left untouched by an update. All required parameters must be populated in order to send to server. - :ivar pilot_stamp: Pilotstamp. Required. + :ivar pilot_stamp: Immutable stamp identifying the pilot. Required. :vartype pilot_stamp: str - :ivar status_reason: Statusreason. + :ivar status_reason: Human-readable reason for the current status. :vartype status_reason: str - :ivar status: PilotStatus. Known values are: "Submitted", "Waiting", "Running", "Done", - "Failed", "Deleted", "Aborted", and "Unknown". + :ivar status: Current pilot status. Known values are: "Submitted", "Waiting", "Running", + "Done", "Failed", "Deleted", "Aborted", and "Unknown". :vartype status: str or ~_generated.models.PilotStatus - :ivar bench_mark: Benchmark. + :ivar bench_mark: Pilot benchmark value. :vartype bench_mark: float - :ivar destination_site: Destinationsite. + :ivar destination_site: Destination site. :vartype destination_site: str - :ivar queue: Queue. + :ivar queue: Batch queue name. :vartype queue: str - :ivar grid_site: Gridsite. + :ivar grid_site: Grid site. :vartype grid_site: str - :ivar grid_type: Gridtype. + :ivar grid_type: Grid type. :vartype grid_type: str - :ivar accounting_sent: Accountingsent. + :ivar accounting_sent: Whether accounting has been sent for this pilot. :vartype accounting_sent: bool - :ivar current_job_id: Currentjobid. + :ivar current_job_id: ID of the job currently running on this pilot. :vartype current_job_id: int """ @@ -1114,26 +1116,26 @@ def __init__( **kwargs: Any ) -> None: """ - :keyword pilot_stamp: Pilotstamp. Required. + :keyword pilot_stamp: Immutable stamp identifying the pilot. Required. :paramtype pilot_stamp: str - :keyword status_reason: Statusreason. + :keyword status_reason: Human-readable reason for the current status. :paramtype status_reason: str - :keyword status: PilotStatus. Known values are: "Submitted", "Waiting", "Running", "Done", - "Failed", "Deleted", "Aborted", and "Unknown". + :keyword status: Current pilot status. Known values are: "Submitted", "Waiting", "Running", + "Done", "Failed", "Deleted", "Aborted", and "Unknown". :paramtype status: str or ~_generated.models.PilotStatus - :keyword bench_mark: Benchmark. + :keyword bench_mark: Pilot benchmark value. :paramtype bench_mark: float - :keyword destination_site: Destinationsite. + :keyword destination_site: Destination site. :paramtype destination_site: str - :keyword queue: Queue. + :keyword queue: Batch queue name. :paramtype queue: str - :keyword grid_site: Gridsite. + :keyword grid_site: Grid site. :paramtype grid_site: str - :keyword grid_type: Gridtype. + :keyword grid_type: Grid type. :paramtype grid_type: str - :keyword accounting_sent: Accountingsent. + :keyword accounting_sent: Whether accounting has been sent for this pilot. :paramtype accounting_sent: bool - :keyword current_job_id: Currentjobid. + :keyword current_job_id: ID of the job currently running on this pilot. :paramtype current_job_id: int """ super().__init__(**kwargs) diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/operations/__init__.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/operations/__init__.py index 5cfdf7253..6becb82b3 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/operations/__init__.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/operations/__init__.py @@ -16,6 +16,7 @@ from ._operations import JobsOperations # type: ignore from ._operations import LollygagOperations # type: ignore from ._operations import MyOperations # type: ignore +from ._operations import PilotsOperations # type: ignore from ._patch import __all__ as _patch_all from ._patch import * @@ -28,6 +29,7 @@ "JobsOperations", "LollygagOperations", "MyOperations", + "PilotsOperations", ] __all__.extend([p for p in _patch_all if p not in __all__]) # pyright: ignore _patch_sdk() diff --git a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/operations/_operations.py b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/operations/_operations.py index 7dcaa92ee..c381967e8 100644 --- a/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/operations/_operations.py +++ b/extensions/gubbins/gubbins-client/src/gubbins/client/_generated/operations/_operations.py @@ -647,6 +647,90 @@ def build_my_pilots_get_pilot_summary_request(**kwargs: Any) -> HttpRequest: # return HttpRequest(method="GET", url=_url, headers=_headers, **kwargs) +def build_pilots_register_pilots_request(**kwargs: Any) -> HttpRequest: + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + accept = _headers.pop("Accept", "application/json") + + # Construct URL + _url = "/api/pilots/" + + # Construct headers + if content_type is not None: + _headers["Content-Type"] = _SERIALIZER.header("content_type", content_type, "str") + _headers["Accept"] = _SERIALIZER.header("accept", accept, "str") + + return HttpRequest(method="POST", url=_url, headers=_headers, **kwargs) + + +def build_pilots_delete_pilots_request(*, pilot_stamps: list[str], **kwargs: Any) -> HttpRequest: + _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) + + # Construct URL + _url = "/api/pilots/" + + # Construct parameters + _params["pilot_stamps"] = _SERIALIZER.query("pilot_stamps", pilot_stamps, "[str]") + + return HttpRequest(method="DELETE", url=_url, params=_params, **kwargs) + + +def build_pilots_update_pilot_metadata_request(**kwargs: Any) -> HttpRequest: # pylint: disable=name-too-long + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + # Construct URL + _url = "/api/pilots/metadata" + + # Construct headers + if content_type is not None: + _headers["Content-Type"] = _SERIALIZER.header("content_type", content_type, "str") + + return HttpRequest(method="PATCH", url=_url, headers=_headers, **kwargs) + + +def build_pilots_search_request(*, page: int = 1, per_page: int = 100, **kwargs: Any) -> HttpRequest: + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = case_insensitive_dict(kwargs.pop("params", {}) or {}) + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + accept = _headers.pop("Accept", "application/json") + + # Construct URL + _url = "/api/pilots/search" + + # Construct parameters + if page is not None: + _params["page"] = _SERIALIZER.query("page", page, "int", minimum=1) + if per_page is not None: + _params["per_page"] = _SERIALIZER.query("per_page", per_page, "int", maximum=10000, minimum=1) + + # Construct headers + if content_type is not None: + _headers["Content-Type"] = _SERIALIZER.header("content_type", content_type, "str") + _headers["Accept"] = _SERIALIZER.header("accept", accept, "str") + + return HttpRequest(method="POST", url=_url, params=_params, headers=_headers, **kwargs) + + +def build_pilots_summary_request(**kwargs: Any) -> HttpRequest: + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + accept = _headers.pop("Accept", "application/json") + + # Construct URL + _url = "/api/pilots/summary" + + # Construct headers + if content_type is not None: + _headers["Content-Type"] = _SERIALIZER.header("content_type", content_type, "str") + _headers["Accept"] = _SERIALIZER.header("accept", accept, "str") + + return HttpRequest(method="POST", url=_url, headers=_headers, **kwargs) + + class WellKnownOperations: """ .. warning:: @@ -2536,6 +2620,12 @@ def search( By default, the search will return all jobs the user has access to, and all the fields of the job will be returned. + A ``PilotStamp`` pseudo-parameter is also accepted in the ``search`` + filter list (operators ``eq`` / ``in`` only): it is transparently + resolved through ``JobToPilotMapping`` into a ``JobID`` filter, + allowing callers to ask "jobs run by this pilot" through the same + endpoint. + :param body: Default value is None. :type body: ~_generated.models.SearchParams :keyword page: Default value is 1. @@ -2576,6 +2666,12 @@ def search( By default, the search will return all jobs the user has access to, and all the fields of the job will be returned. + A ``PilotStamp`` pseudo-parameter is also accepted in the ``search`` + filter list (operators ``eq`` / ``in`` only): it is transparently + resolved through ``JobToPilotMapping`` into a ``JobID`` filter, + allowing callers to ask "jobs run by this pilot" through the same + endpoint. + :param body: Default value is None. :type body: IO[bytes] :keyword page: Default value is 1. @@ -2615,6 +2711,12 @@ def search( By default, the search will return all jobs the user has access to, and all the fields of the job will be returned. + A ``PilotStamp`` pseudo-parameter is also accepted in the ``search`` + filter list (operators ``eq`` / ``in`` only): it is transparently + resolved through ``JobToPilotMapping`` into a ``JobID`` filter, + allowing callers to ask "jobs run by this pilot" through the same + endpoint. + :param body: Is either a SearchParams type or a IO[bytes] type. Default value is None. :type body: ~_generated.models.SearchParams or IO[bytes] :keyword page: Default value is 1. @@ -3181,3 +3283,544 @@ def pilots_get_pilot_summary(self, **kwargs: Any) -> dict[str, int]: return cls(pipeline_response, deserialized, {}) # type: ignore return deserialized # type: ignore + + +class PilotsOperations: + """ + .. warning:: + **DO NOT** instantiate this class directly. + + Instead, you should access the following operations through + :class:`~_generated.Dirac`'s + :attr:`pilots` attribute. + """ + + models = _models + + def __init__(self, *args, **kwargs) -> None: + input_args = list(args) + self._client: PipelineClient = input_args.pop(0) if input_args else kwargs.pop("client") + self._config: DiracConfiguration = input_args.pop(0) if input_args else kwargs.pop("config") + self._serialize: Serializer = input_args.pop(0) if input_args else kwargs.pop("serializer") + self._deserialize: Deserializer = input_args.pop(0) if input_args else kwargs.pop("deserializer") + + @overload + def register_pilots( + self, body: _models.BodyPilotsRegisterPilots, *, content_type: str = "application/json", **kwargs: Any + ) -> Any: + """Register Pilots. + + Register a batch of pilots with their references. + + If any stamp already exists, the whole batch is rejected with a 409. + + :param body: Required. + :type body: ~_generated.models.BodyPilotsRegisterPilots + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + def register_pilots(self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any) -> Any: + """Register Pilots. + + Register a batch of pilots with their references. + + If any stamp already exists, the whole batch is rejected with a 409. + + :param body: Required. + :type body: IO[bytes] + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace + def register_pilots(self, body: Union[_models.BodyPilotsRegisterPilots, IO[bytes]], **kwargs: Any) -> Any: + """Register Pilots. + + Register a batch of pilots with their references. + + If any stamp already exists, the whole batch is rejected with a 409. + + :param body: Is either a BodyPilotsRegisterPilots type or a IO[bytes] type. Required. + :type body: ~_generated.models.BodyPilotsRegisterPilots or IO[bytes] + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[Any] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _json = self._serialize.body(body, "BodyPilotsRegisterPilots") + + _request = build_pilots_register_pilots_request( + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("object", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore + + @distributed_trace + def delete_pilots( # pylint: disable=inconsistent-return-statements + self, *, pilot_stamps: list[str], **kwargs: Any + ) -> None: + """Delete Pilots. + + Delete pilots by stamp. + + Deletes the pilot rows as well as their logs and job associations. + + Age-based retention cleanup is deliberately *not* exposed here: it is + handled by the maintenance task worker. See + ``diracx.logic.pilots.management.delete_pilots``. + + :keyword pilot_stamps: Stamps of the pilots to delete. Required. + :paramtype pilot_stamps: list[str] + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = kwargs.pop("headers", {}) or {} + _params = kwargs.pop("params", {}) or {} + + cls: ClsType[None] = kwargs.pop("cls", None) + + _request = build_pilots_delete_pilots_request( + pilot_stamps=pilot_stamps, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [204]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + if cls: + return cls(pipeline_response, None, {}) # type: ignore + + @overload + def update_pilot_metadata( + self, body: _models.BodyPilotsUpdatePilotMetadata, *, content_type: str = "application/json", **kwargs: Any + ) -> None: + """Update Pilot Metadata. + + Update pilot metadata (status, benchmark, etc.). + + Only fields defined in ``PilotMetadata`` are mutable. ``PilotStamp`` + identifies the row and cannot be changed. + + :param body: Required. + :type body: ~_generated.models.BodyPilotsUpdatePilotMetadata + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + def update_pilot_metadata(self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any) -> None: + """Update Pilot Metadata. + + Update pilot metadata (status, benchmark, etc.). + + Only fields defined in ``PilotMetadata`` are mutable. ``PilotStamp`` + identifies the row and cannot be changed. + + :param body: Required. + :type body: IO[bytes] + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace + def update_pilot_metadata( # pylint: disable=inconsistent-return-statements + self, body: Union[_models.BodyPilotsUpdatePilotMetadata, IO[bytes]], **kwargs: Any + ) -> None: + """Update Pilot Metadata. + + Update pilot metadata (status, benchmark, etc.). + + Only fields defined in ``PilotMetadata`` are mutable. ``PilotStamp`` + identifies the row and cannot be changed. + + :param body: Is either a BodyPilotsUpdatePilotMetadata type or a IO[bytes] type. Required. + :type body: ~_generated.models.BodyPilotsUpdatePilotMetadata or IO[bytes] + :return: None + :rtype: None + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[None] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _json = self._serialize.body(body, "BodyPilotsUpdatePilotMetadata") + + _request = build_pilots_update_pilot_metadata_request( + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [204]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + if cls: + return cls(pipeline_response, None, {}) # type: ignore + + @overload + def search( + self, + body: Optional[_models.SearchParams] = None, + *, + page: int = 1, + per_page: int = 100, + content_type: str = "application/json", + **kwargs: Any + ) -> list[dict[str, Any]]: + """Search. + + Retrieve information about pilots. + + Normal users see only their own VO's pilots. Service administrators see + pilots from all VOs. + + A ``JobID`` pseudo-parameter is also accepted in the ``search`` filter + list (operators ``eq`` / ``in`` only): it is transparently resolved + through ``JobToPilotMapping`` into a ``PilotID`` filter, allowing + callers to ask "pilots that ran this job" through the same endpoint. + + :param body: Default value is None. + :type body: ~_generated.models.SearchParams + :keyword page: Default value is 1. + :paramtype page: int + :keyword per_page: Default value is 100. + :paramtype per_page: int + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: list of dict mapping str to any + :rtype: list[dict[str, any]] + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + def search( + self, + body: Optional[IO[bytes]] = None, + *, + page: int = 1, + per_page: int = 100, + content_type: str = "application/json", + **kwargs: Any + ) -> list[dict[str, Any]]: + """Search. + + Retrieve information about pilots. + + Normal users see only their own VO's pilots. Service administrators see + pilots from all VOs. + + A ``JobID`` pseudo-parameter is also accepted in the ``search`` filter + list (operators ``eq`` / ``in`` only): it is transparently resolved + through ``JobToPilotMapping`` into a ``PilotID`` filter, allowing + callers to ask "pilots that ran this job" through the same endpoint. + + :param body: Default value is None. + :type body: IO[bytes] + :keyword page: Default value is 1. + :paramtype page: int + :keyword per_page: Default value is 100. + :paramtype per_page: int + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: list of dict mapping str to any + :rtype: list[dict[str, any]] + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace + def search( + self, + body: Optional[Union[_models.SearchParams, IO[bytes]]] = None, + *, + page: int = 1, + per_page: int = 100, + **kwargs: Any + ) -> list[dict[str, Any]]: + """Search. + + Retrieve information about pilots. + + Normal users see only their own VO's pilots. Service administrators see + pilots from all VOs. + + A ``JobID`` pseudo-parameter is also accepted in the ``search`` filter + list (operators ``eq`` / ``in`` only): it is transparently resolved + through ``JobToPilotMapping`` into a ``PilotID`` filter, allowing + callers to ask "pilots that ran this job" through the same endpoint. + + :param body: Is either a SearchParams type or a IO[bytes] type. Default value is None. + :type body: ~_generated.models.SearchParams or IO[bytes] + :keyword page: Default value is 1. + :paramtype page: int + :keyword per_page: Default value is 100. + :paramtype per_page: int + :return: list of dict mapping str to any + :rtype: list[dict[str, any]] + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + content_type = content_type if body else None + cls: ClsType[list[dict[str, Any]]] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" if body else None + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + if body is not None: + _json = self._serialize.body(body, "SearchParams") + else: + _json = None + + _request = build_pilots_search_request( + page=page, + per_page=per_page, + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200, 206]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + response_headers = {} + if response.status_code == 206: + response_headers["Content-Range"] = self._deserialize("str", response.headers.get("Content-Range")) + + deserialized = self._deserialize("[{object}]", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, response_headers) # type: ignore + + return deserialized # type: ignore + + @overload + def summary(self, body: _models.SummaryParams, *, content_type: str = "application/json", **kwargs: Any) -> Any: + """Summary. + + Aggregate pilot counts suitable for plotting. + + Normal users see only their own VO's pilots. Service administrators see + pilots from all VOs. + + :param body: Required. + :type body: ~_generated.models.SummaryParams + :keyword content_type: Body Parameter content-type. Content type parameter for JSON body. + Default value is "application/json". + :paramtype content_type: str + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @overload + def summary(self, body: IO[bytes], *, content_type: str = "application/json", **kwargs: Any) -> Any: + """Summary. + + Aggregate pilot counts suitable for plotting. + + Normal users see only their own VO's pilots. Service administrators see + pilots from all VOs. + + :param body: Required. + :type body: IO[bytes] + :keyword content_type: Body Parameter content-type. Content type parameter for binary body. + Default value is "application/json". + :paramtype content_type: str + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + + @distributed_trace + def summary(self, body: Union[_models.SummaryParams, IO[bytes]], **kwargs: Any) -> Any: + """Summary. + + Aggregate pilot counts suitable for plotting. + + Normal users see only their own VO's pilots. Service administrators see + pilots from all VOs. + + :param body: Is either a SummaryParams type or a IO[bytes] type. Required. + :type body: ~_generated.models.SummaryParams or IO[bytes] + :return: any + :rtype: any + :raises ~azure.core.exceptions.HttpResponseError: + """ + error_map: MutableMapping = { + 401: ClientAuthenticationError, + 404: ResourceNotFoundError, + 409: ResourceExistsError, + 304: ResourceNotModifiedError, + } + error_map.update(kwargs.pop("error_map", {}) or {}) + + _headers = case_insensitive_dict(kwargs.pop("headers", {}) or {}) + _params = kwargs.pop("params", {}) or {} + + content_type: Optional[str] = kwargs.pop("content_type", _headers.pop("Content-Type", None)) + cls: ClsType[Any] = kwargs.pop("cls", None) + + content_type = content_type or "application/json" + _json = None + _content = None + if isinstance(body, (IOBase, bytes)): + _content = body + else: + _json = self._serialize.body(body, "SummaryParams") + + _request = build_pilots_summary_request( + content_type=content_type, + json=_json, + content=_content, + headers=_headers, + params=_params, + ) + _request.url = self._client.format_url(_request.url) + + _stream = False + pipeline_response: PipelineResponse = self._client._pipeline.run( # pylint: disable=protected-access + _request, stream=_stream, **kwargs + ) + + response = pipeline_response.http_response + + if response.status_code not in [200]: + map_error(status_code=response.status_code, response=response, error_map=error_map) + raise HttpResponseError(response=response) + + deserialized = self._deserialize("object", pipeline_response.http_response) + + if cls: + return cls(pipeline_response, deserialized, {}) # type: ignore + + return deserialized # type: ignore diff --git a/extensions/gubbins/gubbins-routers/tests/test_gubbins_job_manager.py b/extensions/gubbins/gubbins-routers/tests/test_gubbins_job_manager.py index b26b5482f..a586de84e 100644 --- a/extensions/gubbins/gubbins-routers/tests/test_gubbins_job_manager.py +++ b/extensions/gubbins/gubbins-routers/tests/test_gubbins_job_manager.py @@ -23,6 +23,8 @@ "ConfigSource", "TaskQueueDB", "DevelopmentSettings", + "PilotAgentsDB", + "PilotManagementAccessPolicy", ] ) diff --git a/pixi.lock b/pixi.lock index c2ff7efa8..5c48bd9a1 100644 --- a/pixi.lock +++ b/pixi.lock @@ -19241,8 +19241,8 @@ packages: requires_python: '>=3.11' - pypi: ./ name: diracx - version: 0.0.13.dev10+g09d7149dd.d20260414 - sha256: 1f78b10647ef5e2e13a5438ef3c7f2ac2c051773bf180e30f00fdc45a328425f + version: 0.0.14.dev9+gd8921d974 + sha256: 410903a3be93f06d98b9df3cd204f3a92c585c3f4424516134f6b7272630536b requires_dist: - diracx-api - diracx-cli @@ -19378,7 +19378,7 @@ packages: requires_python: '>=3.11' - pypi: ./diracx-routers name: diracx-routers - version: 0.0.13.dev10+g09d7149dd.d20260414 + version: 0.0.14.dev9+gd8921d974 sha256: ef0c49134e20b3a5232131ec53931179466d3f8db3a16b846660ff39a4978acc requires_dist: - cachetools @@ -19409,8 +19409,8 @@ packages: requires_python: '>=3.11' - pypi: ./diracx-tasks name: diracx-tasks - version: 0.0.13.dev10+g09d7149dd.d20260414 - sha256: 752189cc698d17c76e8c240a4a82f6593a5dc040545c29a504753de135b1bd6b + version: 0.0.14.dev9+gd8921d974 + sha256: 47ecbf1d4db5442abf0cb47de03b0c0ce7064cfb0f9ff9bf009c9a2041d02db1 requires_dist: - croniter - diracx-core