|
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 |
|
15 | | -from google.cloud.spanner_v1 import TransactionOptions, ResultSetMetadata |
| 15 | +from google.cloud.spanner_v1 import ( |
| 16 | + TransactionOptions, |
| 17 | + ResultSetMetadata, |
| 18 | + ExecuteSqlRequest, |
| 19 | +) |
16 | 20 | from google.protobuf import empty_pb2 |
17 | 21 | import test.mockserver_tests.spanner_pb2_grpc as spanner_grpc |
18 | 22 | import test.mockserver_tests.spanner_database_admin_pb2_grpc as database_admin_grpc |
@@ -40,23 +44,25 @@ def get_result(self, sql: str) -> result_set.ResultSet: |
40 | 44 | return result |
41 | 45 |
|
42 | 46 | def get_result_as_partial_result_sets( |
43 | | - self, sql: str |
| 47 | + self, sql: str, started_transaction: transaction.Transaction |
44 | 48 | ) -> [result_set.PartialResultSet]: |
45 | 49 | result: result_set.ResultSet = self.get_result(sql) |
46 | 50 | partials = [] |
47 | 51 | first = True |
48 | 52 | if len(result.rows) == 0: |
49 | 53 | partial = result_set.PartialResultSet() |
50 | | - partial.metadata = result.metadata |
| 54 | + partial.metadata = ResultSetMetadata(result.metadata) |
51 | 55 | partials.append(partial) |
52 | 56 | else: |
53 | 57 | for row in result.rows: |
54 | 58 | partial = result_set.PartialResultSet() |
55 | 59 | if first: |
56 | | - partial.metadata = result.metadata |
| 60 | + partial.metadata = ResultSetMetadata(result.metadata) |
57 | 61 | partial.values.extend(row) |
58 | 62 | partials.append(partial) |
59 | 63 | partials[len(partials) - 1].stats = result.stats |
| 64 | + if started_transaction: |
| 65 | + partials[0].metadata.transaction = started_transaction |
60 | 66 | return partials |
61 | 67 |
|
62 | 68 |
|
@@ -120,9 +126,16 @@ def ExecuteSql(self, request, context): |
120 | 126 | self._requests.append(request) |
121 | 127 | return result_set.ResultSet() |
122 | 128 |
|
123 | | - def ExecuteStreamingSql(self, request, context): |
| 129 | + def ExecuteStreamingSql(self, request: ExecuteSqlRequest, context): |
124 | 130 | self._requests.append(request) |
125 | | - partials = self.mock_spanner.get_result_as_partial_result_sets(request.sql) |
| 131 | + started_transaction = None |
| 132 | + if not request.transaction.begin == TransactionOptions(): |
| 133 | + started_transaction = self.__create_transaction( |
| 134 | + request.session, request.transaction.begin |
| 135 | + ) |
| 136 | + partials = self.mock_spanner.get_result_as_partial_result_sets( |
| 137 | + request.sql, started_transaction |
| 138 | + ) |
126 | 139 | for result in partials: |
127 | 140 | yield result |
128 | 141 |
|
|
0 commit comments