|
46 | 46 | import ckan.plugins as plugins |
47 | 47 | from ckan.common import CKANConfig, config |
48 | 48 |
|
| 49 | +from ckanext.datastore.filters import parse_query_filters, FilterOp |
49 | 50 | from ckanext.datastore.backend import ( |
50 | 51 | DatastoreBackend, |
51 | 52 | DatastoreException, |
|
85 | 86 | _UPSERT = 'upsert' |
86 | 87 | _UPDATE = 'update' |
87 | 88 |
|
88 | | -_OPERATORS = { |
89 | | - 'eq': '=', |
90 | | - 'gt': '>', |
91 | | - 'gte': '>=', |
92 | | - 'lt': '<', |
93 | | - 'lte': '<=', |
94 | | -} |
95 | | - |
96 | 89 | if not os.environ.get('DATASTORE_LOAD'): |
97 | 90 | ValidationError = toolkit.ValidationError # type: ignore |
98 | 91 | else: |
@@ -427,41 +420,47 @@ def _where_clauses( |
427 | 420 | clauses: WhereClauses = [] |
428 | 421 |
|
429 | 422 | idx_gen = itertools.count() |
| 423 | + placeholders = {} |
| 424 | + |
| 425 | + def placeholder(f: str, v: Any) -> str: |
| 426 | + if fields_types[f] == 'text': |
| 427 | + # pSQL can do int_field = "10" |
| 428 | + # but cannot do text_field = 10 |
| 429 | + # this fixes parity there. |
| 430 | + v = str(v) |
| 431 | + p = f"value_{next(idx_gen)}" |
| 432 | + placeholders[p] = v |
| 433 | + return p |
| 434 | + |
| 435 | + def build_clause(fo: FilterOp) -> str: |
| 436 | + '''recursively build clause and placeholders dict''' |
| 437 | + match fo: |
| 438 | + case FilterOp(op='$and', value=v): |
| 439 | + c = (build_clause(f) for f in v) |
| 440 | + return f'({" AND ".join(c)})' if v else 'true' |
| 441 | + case FilterOp(op='$or', value=v): |
| 442 | + c = (build_clause(f) for f in v) |
| 443 | + return f'({" OR ".join(c)})' if v else 'false' |
| 444 | + case FilterOp(field=f, op='eq', value=v): |
| 445 | + return f'{identifier(f)} = :{placeholder(f, v)}' |
| 446 | + case FilterOp(field=f, op='in', value=v): |
| 447 | + ph = (placeholder(f, each) for each in v) |
| 448 | + return f'{identifier(f)} in ({",".join(ph)})' if v else 'false' |
| 449 | + case FilterOp(field=f, op='gt', value=v): |
| 450 | + return f'{identifier(f)} > :{placeholder(f, v)}' |
| 451 | + case FilterOp(field=f, op='gte', value=v): |
| 452 | + return f'{identifier(f)} >= :{placeholder(f, v)}' |
| 453 | + case FilterOp(field=f, op='lt', value=v): |
| 454 | + return f'{identifier(f)} < :{placeholder(f, v)}' |
| 455 | + case FilterOp(field=f, op='lte', value=v): |
| 456 | + return f'{identifier(f)} <= :{placeholder(f, v)}' |
| 457 | + case FilterOp(op=o): |
| 458 | + raise ValidationError( |
| 459 | + {"filters": [f"Unknown filter operation: {o!r}"]} |
| 460 | + ) |
430 | 461 |
|
431 | | - for field, value in filters.items(): |
432 | | - if field not in fields_types: |
433 | | - continue |
434 | | - field_array_type = _is_array_type(fields_types[field]) |
435 | | - |
436 | | - if isinstance(value, list) and not field_array_type: |
437 | | - placeholders = [ |
438 | | - f"value_{next(idx_gen)}" for _ in value |
439 | | - ] |
440 | | - clause_str = ('{0} in ({1})'.format( |
441 | | - sa.column(field), |
442 | | - ','.join(f":{p}" for p in placeholders) |
443 | | - )) |
444 | | - if fields_types[field] == 'text': |
445 | | - # pSQL can do int_field = "10" |
446 | | - # but cannot do text_field = 10 |
447 | | - # this fixes parity there. |
448 | | - value = (str(v) for v in value) |
449 | | - clause = (clause_str, dict(zip(placeholders, value))) |
450 | | - else: |
451 | | - operator = '=' |
452 | | - if isinstance(value, dict): |
453 | | - operator, value = _prepare_where_operator_and_value(value) |
454 | | - if fields_types[field] == 'text': |
455 | | - # pSQL can do int_field = "10" |
456 | | - # but cannot do text_field = 10 |
457 | | - # this fixes parity there. |
458 | | - value = str(value) |
459 | | - placeholder = f"value_{next(idx_gen)}" |
460 | | - clause: tuple[Any, ...] = ( |
461 | | - f'{sa.column(field)} {operator} :{placeholder}', |
462 | | - {placeholder: value} |
463 | | - ) |
464 | | - clauses.append(clause) |
| 462 | + fltr = parse_query_filters(filters, {"fields": fields_types}) |
| 463 | + clauses.append((build_clause(fltr), placeholders)) |
465 | 464 |
|
466 | 465 | # add full-text search where clause |
467 | 466 | q: Union[dict[str, str], str, Any] = data_dict.get('q') |
@@ -498,18 +497,6 @@ def _where_clauses( |
498 | 497 | return clauses |
499 | 498 |
|
500 | 499 |
|
501 | | -def _prepare_where_operator_and_value(value: dict[str, Any]) -> tuple[str, Any]: |
502 | | - try: |
503 | | - [(key, val)] = value.items() |
504 | | - except ValueError: |
505 | | - return '=', value |
506 | | - |
507 | | - try: |
508 | | - return _OPERATORS[key], val |
509 | | - except KeyError: |
510 | | - return '=', value |
511 | | - |
512 | | - |
513 | 500 | def _textsearch_query( |
514 | 501 | lang: str, q: Optional[Union[str, dict[str, str], Any]], plain: bool, |
515 | 502 | full_text: Optional[str]) -> tuple[str, dict[str, str]]: |
@@ -1505,8 +1492,10 @@ def search_data(context: Context, data_dict: dict[str, Any]): |
1505 | 1492 | else: |
1506 | 1493 | operator = 'gt' |
1507 | 1494 | last_id_select = 'max(_id)' |
1508 | | - is_keyset = any(i[0].startswith( |
1509 | | - f'_id {_OPERATORS[operator]} ') for i in query_dict['where']) |
| 1495 | + is_keyset = any( |
| 1496 | + i[0].startswith( f'_id > ') or i[0].startswith('_id < ') |
| 1497 | + for i in query_dict['where'] |
| 1498 | + ) |
1510 | 1499 |
|
1511 | 1500 | if is_keyset: |
1512 | 1501 | final_statement = '{where} {sort} LIMIT {limit}' |
|
0 commit comments