diff --git a/src/amp/loaders/implementations/_postgres_helpers.py b/src/amp/loaders/implementations/_postgres_helpers.py
index eb0f71e..8b47635 100644
--- a/src/amp/loaders/implementations/_postgres_helpers.py
+++ b/src/amp/loaders/implementations/_postgres_helpers.py
@@ -7,6 +7,31 @@
 from pyarrow import csv
 
 
+def _quote_identifier(name: str) -> str:
+    """Return a safely double-quoted PostgreSQL identifier.
+
+    Wraps the name in double quotes and escapes any embedded double quotes
+    by doubling them (standard SQL identifier quoting rules). This prevents
+    syntax errors when column or table names collide with PostgreSQL reserved
+    keywords such as `to`, `end`, `index`, etc.
+
+    Args:
+        name: Raw identifier (column or table name)
+
+    Returns:
+        Double-quoted identifier safe for use in SQL DDL/DML statements
+
+    Examples:
+        >>> _quote_identifier('to')
+        '"to"'
+        >>> _quote_identifier('block_num')
+        '"block_num"'
+        >>> _quote_identifier('weird"name')
+        '"weird\"\"name"'
+    """
+    return '"' + name.replace('"', '""') + '"'
+
+
 def prepare_csv_data(data: Union[pa.RecordBatch, pa.Table]) -> Tuple[io.StringIO, List[str]]:
     """
     Convert Arrow data to CSV format optimized for PostgreSQL COPY.
@@ -43,7 +68,8 @@ def prepare_csv_data(data: Union[pa.RecordBatch, pa.Table]) -> Tuple[io.StringIO
 
     csv_buffer = io.StringIO(csv_data)
 
-    # Get column names from Arrow schema
+    # Get column names from Arrow schema (raw, unquoted — quoting is the
+    # caller's responsibility when constructing SQL identifiers)
     column_names = [field.name for field in data.schema]
 
     return csv_buffer, column_names
@@ -103,12 +129,14 @@ def prepare_insert_data(data: Union[pa.RecordBatch, pa.Table]) -> Tuple[str, Lis
     # Convert Arrow data to Python objects
     data_dict = data.to_pydict()
 
-    # Get column names
+    # Get column names and quote each one so reserved keywords (e.g. `to`,
+    # `end`, `index`) do not cause syntax errors in the INSERT statement.
     column_names = [field.name for field in data.schema]
+    quoted_columns = [_quote_identifier(c) for c in column_names]
 
     # Create INSERT statement template
     placeholders = ', '.join(['%s'] * len(column_names))
-    insert_sql = f'({", ".join(column_names)}) VALUES ({placeholders})'
+    insert_sql = f'({", ".join(quoted_columns)}) VALUES ({placeholders})'
 
     # Prepare data for insertion
     rows = []
diff --git a/tests/unit/test_postgres_helpers.py b/tests/unit/test_postgres_helpers.py
new file mode 100644
index 0000000..7b27c33
--- /dev/null
+++ b/tests/unit/test_postgres_helpers.py
@@ -0,0 +1,71 @@
+"""
+Unit tests for PostgreSQL SQL identifier quoting in _postgres_helpers.py.
+
+Verifies that reserved keyword column names (e.g. 'to', 'from') are double-quoted in
+generated INSERT SQL to prevent syntax errors.
+"""
+
+import pyarrow as pa
+import pytest
+
+from amp.loaders.implementations._postgres_helpers import prepare_insert_data
+
+
+@pytest.fixture
+def eth_tx_batch():
+    """Arrow RecordBatch modelling Ethereum transaction data with reserved keyword column names."""
+    schema = pa.schema(
+        [
+            pa.field('block_hash', pa.binary(32), nullable=False),
+            pa.field('block_num', pa.uint64(), nullable=False),
+            pa.field('tx_index', pa.uint32(), nullable=False),
+            pa.field('tx_hash', pa.binary(32), nullable=False),
+            pa.field('to', pa.binary(20), nullable=True),  # reserved keyword; nullable for contract creation
+            pa.field('nonce', pa.uint64(), nullable=False),
+            pa.field('value', pa.decimal128(38, 0), nullable=False),
+            pa.field('from', pa.binary(20), nullable=False),  # reserved keyword
+        ]
+    )
+    return pa.RecordBatch.from_arrays(
+        [
+            pa.array([b'\x01' * 32, b'\x02' * 32], type=pa.binary(32)),
+            pa.array([18_000_000, 18_000_001], type=pa.uint64()),
+            pa.array([0, 1], type=pa.uint32()),
+            pa.array([b'\x03' * 32, b'\x04' * 32], type=pa.binary(32)),
+            pa.array([b'\xaa' * 20, None], type=pa.binary(20)),  # None = contract creation tx
+            pa.array([0, 1], type=pa.uint64()),
+            pa.array([1_000_000_000, 2_000_000_000], type=pa.decimal128(38, 0)),
+            pa.array([b'\xbb' * 20, b'\xcc' * 20], type=pa.binary(20)),
+        ],
+        schema=schema,
+    )
+
+
+@pytest.mark.unit
+class TestInsertSqlIdentifierQuoting:
+    """
+    prepare_insert_data() must double-quote all column names in the generated
+    INSERT SQL template to prevent reserved-keyword syntax errors.
+    """
+
+    def test_all_column_names_are_quoted_in_insert_sql(self, eth_tx_batch):
+        """Every column in the INSERT template must be wrapped in double quotes."""
+        sql_template, _ = prepare_insert_data(eth_tx_batch)
+
+        for col_name in eth_tx_batch.schema.names:
+            assert f'"{col_name}"' in sql_template, (
+                f"Column '{col_name}' must be double-quoted in the INSERT SQL template.\n"
+                f'Generated template: {sql_template}'
+            )
+
+    def test_placeholder_count_matches_column_count(self, eth_tx_batch):
+        """The VALUES clause must have exactly one %s placeholder per column."""
+        sql_template, _ = prepare_insert_data(eth_tx_batch)
+
+        assert sql_template.count('%s') == len(eth_tx_batch.schema)
+
+    def test_row_count_preserved(self, eth_tx_batch):
+        """The returned rows list must contain one tuple per input row."""
+        _, rows = prepare_insert_data(eth_tx_batch)
+
+        assert len(rows) == eth_tx_batch.num_rows