Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 39 additions & 5 deletions packages/opal-server/opal_server/git_fetcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,10 +190,21 @@ async def fetch_and_notify_on_changes(
GitPolicyFetcher.repos_last_fetched[
self.source_id
] = datetime.datetime.now()
await run_sync(
repo.remotes[self._remote].fetch,
callbacks=self._auth_callbacks,
)
try:
await run_sync(
repo.remotes[self._remote].fetch,
callbacks=self._auth_callbacks,
)
except pygit2.GitError:
logger.exception(
"Failed to fetch remote {remote} for {url}, "
"cleaning up invalid repo",
remote=self._remote,
url=self._source.url,
)
self._invalidate_repo_cache()
self._cleanup_repo_path(self._repo_path)
return
logger.debug(f"Fetch completed: {self._source.url}")

# New commits might be present because of a previous fetch made by another scope
Expand All @@ -204,7 +215,8 @@ async def fetch_and_notify_on_changes(
logger.warning(
"Deleting invalid repo: {path}", path=self._repo_path
)
shutil.rmtree(self._repo_path)
self._invalidate_repo_cache()
self._cleanup_repo_path(self._repo_path)
else:
logger.info("Repo not found at {path}", path=self._repo_path)

Expand All @@ -215,6 +227,26 @@ def _discover_repository(self, path: Path) -> bool:
git_path: Path = path / ".git"
return discover_repository(str(path)) and git_path.exists()

@staticmethod
def _cleanup_repo_path(repo_path: Path):
    """Best-effort removal of whatever is left at a repository path.

    Handles a real directory tree, a regular file, and a symlink (including
    a broken one). Does nothing when the path neither exists nor is a
    symlink. Failures never propagate to the caller.

    Args:
        repo_path: Location of the (possibly partial or invalid) repository.
    """
    path = Path(repo_path)
    # exists() follows symlinks and is False for a broken link, so the
    # symlink check is needed to notice dangling links at all.
    if not (path.is_symlink() or path.exists()):
        return

    logger.info(
        "Cleaning up repo path: {path}",
        path=repo_path,
    )

    if path.is_symlink() or not path.is_dir():
        # shutil.rmtree refuses symlinks and non-directories; combined with
        # ignore_errors=True that refusal was silent and the entry survived.
        # Remove the link/file itself instead (the link target is untouched).
        try:
            path.unlink()
        except OSError:
            logger.warning(
                "Could not remove repo path: {path}",
                path=repo_path,
            )
        return

    shutil.rmtree(str(repo_path), ignore_errors=True)

def _invalidate_repo_cache(self):
    """Evict this fetcher's entries from the class-level caches.

    Drops the ``repos`` entry keyed by the stringified repo path and the
    ``repos_last_fetched`` entry keyed by the source id. Keys that are not
    present are ignored.
    """
    repo_key = str(self._repo_path)
    GitPolicyFetcher.repos.pop(repo_key, None)

    # NOTE(review): other code in this class indexes repos_last_fetched with
    # `self.source_id`; confirm it yields the same key as this call.
    source_key = GitPolicyFetcher.source_id(self._source)
    GitPolicyFetcher.repos_last_fetched.pop(source_key, None)

async def _clone(self):
logger.info(
"Cloning repo at '{url}' to '{path}'",
Expand All @@ -230,6 +262,8 @@ async def _clone(self):
)
except pygit2.GitError:
logger.exception(f"Could not clone repo at {self._source.url}")
self._cleanup_repo_path(self._repo_path)
self._invalidate_repo_cache()
else:
logger.info(f"Clone completed: {self._source.url}")
await self._notify_on_changes(repo)
Expand Down
42 changes: 41 additions & 1 deletion packages/opal-server/opal_server/pubsub.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import socket
import time
from contextlib import contextmanager
from contextvars import ContextVar
from threading import Lock
from typing import Dict, Generator, List, Optional, Set, Tuple, Union, cast
from urllib.parse import urlparse
from uuid import UUID, uuid4

from fastapi import APIRouter, Depends, WebSocket
Expand Down Expand Up @@ -141,7 +143,12 @@ def __init__(self, signer: JWTSigner, broadcaster_uri: str = None):

self.broadcaster = None
if broadcaster_uri is not None:
logger.info(f"Initializing broadcaster for server<->server communication")
safe_uri = self._mask_uri_password(broadcaster_uri)
logger.info(
"Initializing broadcaster for server<->server communication, uri={uri}",
uri=safe_uri,
)
self._validate_broadcast_uri(broadcaster_uri)
self.broadcaster = EventBroadcaster(
broadcaster_uri,
notifier=self.notifier,
Expand Down Expand Up @@ -202,6 +209,39 @@ async def websocket_rpc_endpoint(
finally:
await websocket.close()

@staticmethod
def _mask_uri_password(uri: str) -> str:
"""Return the URI with password masked."""
parsed = urlparse(uri)
if parsed.password:
masked = parsed._replace(
netloc=f"{parsed.username}:***@{parsed.hostname}"
+ (f":{parsed.port}" if parsed.port else "")
)
return masked.geturl()
return uri

@staticmethod
def _validate_broadcast_uri(uri: str):
    """Check that the broadcast URI names a resolvable host.

    Every problem is reported at ERROR level and the function returns
    normally; it never raises for a bad URI.

    Args:
        uri: The configured broadcast URI.
    """
    try:
        parsed = urlparse(uri)
        hostname = parsed.hostname
    except ValueError as e:
        # urlparse rejects some malformed inputs (e.g. unbalanced IPv6
        # brackets). The URI itself is not logged because it cannot be
        # parsed and therefore cannot be redacted.
        logger.error(
            "Broadcast URI could not be parsed: {error}. Check your BROADCAST_URI configuration.",
            error=str(e),
        )
        return

    if not hostname:
        # Strip any "user:password@" prefix so credentials never reach logs.
        redacted = uri
        if "@" in parsed.netloc:
            host_part = parsed.netloc.rpartition("@")[2]
            redacted = parsed._replace(netloc=host_part).geturl()
        logger.error(
            "Broadcast URI has no hostname: {uri}. Check your BROADCAST_URI configuration.",
            uri=redacted,
        )
        return

    try:
        socket.getaddrinfo(hostname, None)
    except (socket.gaierror, UnicodeError) as e:
        # UnicodeError comes from IDNA encoding of the hostname (for example
        # an over-long label) before any lookup is attempted.
        logger.error(
            "Cannot resolve broadcast URI hostname '{hostname}': {error}. Check your BROADCAST_URI configuration.",
            hostname=hostname,
            error=str(e),
        )

@staticmethod
async def _verify_permitted_topics(
topics: Union[TopicList, ALL_TOPICS], channel: RpcChannel
Expand Down
17 changes: 16 additions & 1 deletion packages/opal-server/opal_server/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,7 +329,7 @@ async def start_server_background_tasks(self):
await self.broadcast_listening_context.__aenter__()
# if the broadcast channel is closed, we want to restart worker process because statistics can't be reliable anymore
self.broadcast_listening_context._event_broadcaster.get_reader_task().add_done_callback(
lambda _: self._graceful_shutdown()
self._on_broadcaster_disconnected
)
asyncio.create_task(self.opal_statistics.run())
self.pubsub.endpoint.notifier.register_unsubscribe_event(
Expand Down Expand Up @@ -398,6 +398,21 @@ async def stop_server_background_tasks(self):
except Exception:
logger.exception("exception while shutting down background tasks")

def _on_broadcaster_disconnected(self, task: asyncio.Task):
    """Done-callback for the broadcast listener task.

    Logs why the task finished, then triggers a graceful shutdown of the
    worker in every case.

    Args:
        task: The finished (or cancelled) broadcast reader task.
    """
    failure = None
    try:
        failure = task.exception()
    except (asyncio.CancelledError, asyncio.InvalidStateError):
        # Cancelled or not-yet-done task: there is no exception to report.
        pass

    if failure is None:
        logger.warning("Broadcast channel disconnected.")
    else:
        logger.error(
            "Broadcast channel connection failed: {error}. Check BROADCAST_URI configuration.",
            error=failure,
        )

    self._graceful_shutdown()

def _graceful_shutdown(self):
logger.info("Trigger worker graceful shutdown")
os.kill(os.getpid(), signal.SIGTERM)
77 changes: 77 additions & 0 deletions packages/opal-server/tests/test_broadcast_uri_validation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
"""Tests for broadcast URI validation and error logging (issue #716)."""

import asyncio
import logging
import socket
from unittest.mock import MagicMock, patch

import pytest

from opal_server.pubsub import PubSub


class TestValidateBroadcastUri:
    """Logging behaviour of ``PubSub._validate_broadcast_uri``."""

    def test_invalid_hostname_logs_error(self, caplog):
        """A hostname whose lookup fails must yield an ERROR record."""
        failing_lookup = patch(
            "opal_server.pubsub.socket.getaddrinfo",
            side_effect=socket.gaierror("Name or service not known"),
        )
        with caplog.at_level(logging.ERROR, logger="opal.server"), failing_lookup:
            PubSub._validate_broadcast_uri("postgres://invalid-host-xyz:5432/db")

        messages = [record.message for record in caplog.records]
        assert any("Cannot resolve broadcast URI hostname" in m for m in messages)

    def test_valid_hostname_no_error(self, caplog):
        """A hostname whose lookup succeeds must not yield a resolve error."""
        working_lookup = patch(
            "opal_server.pubsub.socket.getaddrinfo",
            return_value=[(None,) * 5],
        )
        with caplog.at_level(logging.ERROR, logger="opal.server"), working_lookup:
            PubSub._validate_broadcast_uri("postgres://localhost:5432/db")

        messages = [record.message for record in caplog.records]
        assert not any("Cannot resolve" in m for m in messages)

    def test_uri_without_hostname_logs_error(self, caplog):
        """A URI lacking a host part must yield an ERROR record."""
        with caplog.at_level(logging.ERROR, logger="opal.server"):
            PubSub._validate_broadcast_uri("postgres://")

        messages = [record.message for record in caplog.records]
        assert any("has no hostname" in m for m in messages)


class TestMaskUriPassword:
    """Behaviour of ``PubSub._mask_uri_password``."""

    def test_password_masked(self):
        """The password is hidden while the username stays visible."""
        masked = PubSub._mask_uri_password("postgres://user:secret@host:5432/db")
        assert "secret" not in masked
        assert "***" in masked
        assert "user" in masked

    def test_no_password(self):
        """A URI without credentials comes back untouched."""
        plain_uri = "postgres://host:5432/db"
        masked = PubSub._mask_uri_password(plain_uri)
        assert masked == plain_uri


class TestOnBroadcasterDisconnected:
    """Behaviour of ``OpalServer._on_broadcaster_disconnected``."""

    def test_callback_logs_exception_on_error(self, caplog):
        """A task that ended with an exception is reported at ERROR level."""
        from opal_server.server import OpalServer

        failed_task = MagicMock(spec=asyncio.Task)
        failed_task.exception.return_value = ConnectionError("connection refused")

        # Bypass __init__ so no real server state is constructed.
        server = object.__new__(OpalServer)

        with caplog.at_level(logging.ERROR, logger="opal.server"), patch.object(
            OpalServer, "_graceful_shutdown"
        ):
            server._on_broadcaster_disconnected(failed_task)

        messages = [record.message for record in caplog.records]
        assert any("Broadcast channel connection failed" in m for m in messages)

    def test_callback_calls_graceful_shutdown(self):
        """The callback always ends by requesting a graceful shutdown."""
        from opal_server.server import OpalServer

        finished_task = MagicMock(spec=asyncio.Task)
        finished_task.exception.return_value = None

        # Bypass __init__ so no real server state is constructed.
        server = object.__new__(OpalServer)

        with patch.object(OpalServer, "_graceful_shutdown") as shutdown_spy:
            server._on_broadcaster_disconnected(finished_task)

        shutdown_spy.assert_called_once()
Loading