Skip to content

Commit 210c1e9

Browse files
authored
fix default_paramstyle (#27)
* fix default_paramstyle * fix validator * support lastrowid * support null value * fix type hint
1 parent b2ea173 commit 210c1e9

File tree

4 files changed

+22
-15
lines changed

4 files changed

+22
-15
lines changed

pydataapi/dbapi.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,11 +112,16 @@ def __init__(self, data_api: DataAPI) -> None:
112112

113113
self._rows: List[List] = []
114114
self._rowcount: int = -1
115+
self._lastrowid: Optional[int] = None
115116

116117
@property
117118
def rowcount(self) -> int:
118119
return self._rowcount
119120

121+
@property
122+
def lastrowid(self) -> Optional[int]:
123+
return self._lastrowid
124+
120125
def close(self) -> None:
121126
self.closed = True
122127

@@ -131,6 +136,7 @@ def execute(
131136
rows: List[List] = getattr(result, '_rows')
132137
self._rows = rows
133138
self._rowcount = len(rows) or result.number_of_records_updated
139+
self._lastrowid = result.generated_fields_first # type: ignore
134140
return self
135141

136142
def executemany(
@@ -141,6 +147,9 @@ def executemany(
141147
self._rows = [result.generated_fields for result in results]
142148
self._rowcount = len(self._rows)
143149
self.description = []
150+
self._lastrowid = ( # type: ignore
151+
results[-1].generated_fields_first if results else None # type: ignore
152+
)
144153
return self
145154

146155
def fetchone(self) -> Optional[List]:

pydataapi/dialect.py

Lines changed: 2 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,6 @@ class DataAPIDialect(DefaultDialect, ABC):
3737

3838
supports_comments = True
3939
inline_comments = True
40-
default_paramstyle = "named"
4140

4241
cte_follows_insert = True
4342

@@ -168,21 +167,13 @@ def _detect_charset(self, connection: Any) -> Any: # pragma: no cover
168167
pass
169168

170169
name = "mysql"
171-
statement_compiler = MySQLCompiler
172-
ddl_compiler = MySQLDDLCompiler
173-
type_compiler = MySQLTypeCompiler
174-
175-
preparer = MySQLIdentifierPreparer
170+
default_paramstyle = "named"
176171

177172

178173
class PostgreSQLDataAPIDialect(PGDialect, DataAPIDialect):
179174
name = "postgresql"
175+
default_paramstyle = "named"
180176
supports_alter = True
181177
max_identifier_length = 63
182178
supports_sane_rowcount = True
183-
statement_compiler = PGCompiler
184-
ddl_compiler = PGDDLCompiler
185-
type_compiler = PGTypeCompiler
186-
preparer = PGIdentifierPreparer
187-
inspector = PGInspector
188179
isolation_level = None

pydataapi/pydataapi.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -164,8 +164,13 @@ def __len__(self) -> int:
164164

165165
def __init__(self, response: Dict):
166166
self._response = response
167-
self._rows: Sequence[List[Dict]] = [
168-
[tuple(column.values())[0] for column in row]
167+
self._rows: Sequence[List] = [
168+
[
169+
None
170+
if tuple(column.keys())[0] == 'isNull'
171+
else tuple(column.values())[0]
172+
for column in row
173+
]
169174
for row in response.get('records', []) # type: ignore
170175
]
171176
self._column_metadata: List[Dict[str, Any]] = response.get('columnMetadata', [])
@@ -251,7 +256,7 @@ def convert_parameters(cls, v: Any) -> Any:
251256

252257
@validator('parameterSets', pre=True)
253258
def convert_parameter_sets(cls, v: Any) -> Any:
254-
if isinstance(v, list):
259+
if isinstance(v, (list, tuple)):
255260
return [create_sql_parameters(parameter) for parameter in v]
256261
return v
257262

tests/pydataapi/test_dbaapi.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,14 +58,15 @@ def test_rollback_not_called(mocked_client) -> None:
5858
def test_execute_insert(mocked_client, mocker) -> None:
5959
mocked_client.begin_transaction.return_value = {'transactionId': 'abc'}
6060
mocked_client.execute_statement.return_value = {
61-
'generatedFields': [],
61+
'generatedFields': [{'longValue': 3}],
6262
'numberOfRecordsUpdated': 1,
6363
}
6464
data_api = connect(
6565
resource_arn='dummy', secret_arn='dummy', database='test', client=mocked_client
6666
)
6767
results = data_api.execute("insert into pets values(1, 'cat')")
6868
assert list(results.fetchall()) == []
69+
assert results.lastrowid == 3
6970
assert mocked_client.execute_statement.call_args == mocker.call(
7071
continueAfterTimeout=True,
7172
includeResultMetadata=True,
@@ -215,6 +216,7 @@ def test_execute_insert_parameter_set(mocked_client, mocker) -> None:
215216
rows = results.fetchall()
216217
assert len(rows) == 2
217218
assert rows == [[3], [4]]
219+
assert results.lastrowid == 4
218220

219221
assert mocked_client.batch_execute_statement.call_args == mocker.call(
220222
resourceArn='dummy',

0 commit comments

Comments
 (0)