Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 31 additions & 3 deletions src/amp/loaders/implementations/_postgres_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,31 @@
from pyarrow import csv


def _quote_identifier(name: str) -> str:
"""Return a safely double-quoted PostgreSQL identifier.

Wraps the name in double quotes and escapes any embedded double quotes
by doubling them (standard SQL identifier quoting rules). This prevents
syntax errors when column or table names collide with PostgreSQL reserved
keywords such as `to`, `end`, `index`, etc.

Args:
name: Raw identifier (column or table name)

Returns:
Double-quoted identifier safe for use in SQL DDL/DML statements

Examples:
>>> _quote_identifier('to')
'"to"'
>>> _quote_identifier('block_num')
'"block_num"'
>>> _quote_identifier('weird"name')
'"weird\"\"name"'
"""
return '"' + name.replace('"', '""') + '"'


def prepare_csv_data(data: Union[pa.RecordBatch, pa.Table]) -> Tuple[io.StringIO, List[str]]:
"""
Convert Arrow data to CSV format optimized for PostgreSQL COPY.
Expand Down Expand Up @@ -43,7 +68,8 @@ def prepare_csv_data(data: Union[pa.RecordBatch, pa.Table]) -> Tuple[io.StringIO

csv_buffer = io.StringIO(csv_data)

# Get column names from Arrow schema
# Get column names from Arrow schema (raw, unquoted — quoting is the
# caller's responsibility when constructing SQL identifiers)
column_names = [field.name for field in data.schema]

return csv_buffer, column_names
Expand Down Expand Up @@ -103,12 +129,14 @@ def prepare_insert_data(data: Union[pa.RecordBatch, pa.Table]) -> Tuple[str, Lis
# Convert Arrow data to Python objects
data_dict = data.to_pydict()

# Get column names
# Get column names and quote each one so reserved keywords (e.g. `to`,
# `end`, `index`) do not cause syntax errors in the INSERT statement.
column_names = [field.name for field in data.schema]
quoted_columns = [_quote_identifier(c) for c in column_names]

# Create INSERT statement template
placeholders = ', '.join(['%s'] * len(column_names))
insert_sql = f'({", ".join(column_names)}) VALUES ({placeholders})'
insert_sql = f'({", ".join(quoted_columns)}) VALUES ({placeholders})'

# Prepare data for insertion
rows = []
Expand Down
71 changes: 71 additions & 0 deletions tests/unit/test_postgres_helpers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
"""
Unit tests for PostgreSQL SQL identifier quoting in _postgres_helpers.py.

Verifies that reserved keyword column names (e.g. 'to', 'from') are double-quoted in
the generated INSERT SQL template to prevent syntax errors.
"""

import pyarrow as pa
import pytest

from amp.loaders.implementations._postgres_helpers import prepare_insert_data


@pytest.fixture
def eth_tx_batch():
    """Arrow RecordBatch modelling Ethereum transaction data with reserved keyword column names."""
    # Schema deliberately includes 'to' and 'from', both PostgreSQL reserved
    # keywords, to exercise identifier quoting.
    fields = [
        pa.field('block_hash', pa.binary(32), nullable=False),
        pa.field('block_num', pa.uint64(), nullable=False),
        pa.field('tx_index', pa.uint32(), nullable=False),
        pa.field('tx_hash', pa.binary(32), nullable=False),
        pa.field('to', pa.binary(20), nullable=True),  # reserved keyword; nullable for contract creation
        pa.field('nonce', pa.uint64(), nullable=False),
        pa.field('value', pa.decimal128(38, 0), nullable=False),
        pa.field('from', pa.binary(20), nullable=False),  # reserved keyword
    ]
    columns = [
        pa.array([b'\x01' * 32, b'\x02' * 32], type=pa.binary(32)),
        pa.array([18_000_000, 18_000_001], type=pa.uint64()),
        pa.array([0, 1], type=pa.uint32()),
        pa.array([b'\x03' * 32, b'\x04' * 32], type=pa.binary(32)),
        pa.array([b'\xaa' * 20, None], type=pa.binary(20)),  # None = contract creation tx
        pa.array([0, 1], type=pa.uint64()),
        pa.array([1_000_000_000, 2_000_000_000], type=pa.decimal128(38, 0)),
        pa.array([b'\xbb' * 20, b'\xcc' * 20], type=pa.binary(20)),
    ]
    return pa.RecordBatch.from_arrays(columns, schema=pa.schema(fields))


@pytest.mark.unit
class TestInsertSqlIdentifierQuoting:
    """
    prepare_insert_data() must double-quote all column names in the generated
    INSERT SQL template to prevent reserved-keyword syntax errors.
    """

    def test_all_column_names_are_quoted_in_insert_sql(self, eth_tx_batch):
        """Every column in the INSERT template must be wrapped in double quotes."""
        sql_template = prepare_insert_data(eth_tx_batch)[0]

        for col_name in eth_tx_batch.schema.names:
            quoted = f'"{col_name}"'
            assert quoted in sql_template, (
                f"Column '{col_name}' must be double-quoted in the INSERT SQL template.\n"
                f'Generated template: {sql_template}'
            )

    def test_placeholder_count_matches_column_count(self, eth_tx_batch):
        """The VALUES clause must have exactly one %s placeholder per column."""
        template, _rows = prepare_insert_data(eth_tx_batch)
        expected_placeholders = len(eth_tx_batch.schema)

        assert template.count('%s') == expected_placeholders

    def test_row_count_preserved(self, eth_tx_batch):
        """The returned rows list must contain one tuple per input row."""
        result = prepare_insert_data(eth_tx_batch)

        assert len(result[1]) == eth_tx_batch.num_rows
Loading