Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 20 additions & 26 deletions sqlparse/engine/statement_splitter.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,9 +16,8 @@ def __init__(self):

def _reset(self):
"""Set the filter attributes to its default values"""
self._in_declare = False
self._in_case = False
self._is_create = False
self._case_depth = 0
self._stmt_start = True
self._begin_depth = 0

self.consume_ws = False
Expand All @@ -33,6 +32,10 @@ def _change_splitlevel(self, ttype, value):
return 1
elif ttype is T.Punctuation and value == ')':
return -1
elif ttype is T.Punctuation and value == ';' and self._stmt_start:
self._begin_depth = max(0, self._begin_depth - 1)
self._begin_depth -= 1
return -1
elif ttype not in T.Keyword: # if normal token return
return 0

Expand All @@ -41,40 +44,30 @@ def _change_splitlevel(self, ttype, value):
# returning
unified = value.upper()

# three keywords begin with CREATE, but only one of them is DDL
# DDL Create though can contain more words such as "or replace"
if ttype is T.Keyword.DDL and unified.startswith('CREATE'):
self._is_create = True
return 0

# can have nested declare inside of being...
if unified == 'DECLARE' and self._is_create and self._begin_depth == 0:
self._in_declare = True
if unified == 'BEGIN' and not self._stmt_start:
self._begin_depth += 1
return 1

if unified == 'BEGIN':
self._begin_depth += 1
if self._is_create:
# FIXME(andi): This makes no sense. ## this comment neither
return 1
return 0
if self._stmt_start:
self._stmt_start = False

# BEGIN and CASE/WHEN both end with END
if unified == 'END':
if not self._in_case:
if not self._case_depth:
self._begin_depth = max(0, self._begin_depth - 1)
else:
self._in_case = False
self._case_depth = max(0, self._case_depth - 1)
return -1

if (unified in ('IF', 'FOR', 'WHILE', 'CASE')
and self._is_create and self._begin_depth > 0):
if unified in ('IF', 'FOR', 'WHILE', 'CASE'):
if unified == 'CASE':
self._in_case = True
return 1
self._case_depth += 1
if self._begin_depth > 0:
return 1

if unified in ('END IF', 'END FOR', 'END WHILE'):
return -1
if unified in ('END IF', 'END FOR', 'END WHILE', 'END CASE'):
if self._begin_depth > 0:
return -1

# Default
return 0
Expand All @@ -84,6 +77,7 @@ def process(self, stream):
EOS_TTYPE = T.Whitespace, T.Comment.Single

# Run over all stream tokens
sb = ""
for ttype, value in stream:
# Yield token if we finished a statement and there's no whitespaces
# It will count newline token as a non whitespace. In this context
Expand Down
2 changes: 1 addition & 1 deletion sqlparse/keywords.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@
(r'(?<![\w\])])(\[[^\]\[]+\])', tokens.Name),
(r'((LEFT\s+|RIGHT\s+|FULL\s+)?(INNER\s+|OUTER\s+|STRAIGHT\s+)?'
r'|(CROSS\s+|NATURAL\s+)?)?JOIN\b', tokens.Keyword),
(r'END(\s+IF|\s+LOOP|\s+WHILE)?\b', tokens.Keyword),
(r'END(\s+IF|\s+LOOP|\s+WHILE|\s+CASE)?\b', tokens.Keyword),
(r'NOT\s+NULL\b', tokens.Keyword),
(r'(ASC|DESC)(\s+NULLS\s+(FIRST|LAST))?\b', tokens.Keyword.Order),
(r'(ASC|DESC)\b', tokens.Keyword.Order),
Expand Down
8 changes: 8 additions & 0 deletions tests/files/begincommit_1.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
-- postgresql, mysql
begin
/*comment*/
--comment
;
update foo
set bar = 1;
commit;
14 changes: 14 additions & 0 deletions tests/files/begincommit_2.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
-- SQL SERVER
BEGIN TRAN T1;

UPDATE table1 ...;

BEGIN TRAN M2 WITH MARK;
UPDATE table2 ...;
SELECT * from table1;

COMMIT TRAN M2;

UPDATE table3 ...;

COMMIT TRAN T1;
2 changes: 1 addition & 1 deletion tests/files/begintag_2.sql → tests/files/beginend_1.sql
Original file line number Diff line number Diff line change
Expand Up @@ -10,4 +10,4 @@ CREATE TRIGGER IF NOT EXISTS remove_if_it_was_the_last_file_link
BEGIN
DELETE FROM dir_entries
WHERE dir_entries.inode = OLD.child_entry;
END;
END;
9 changes: 9 additions & 0 deletions tests/files/beginend_2.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
--Trino
WITH FUNCTION meaning_of_life()
RETURNS tinyint
BEGIN
DECLARE a tinyint DEFAULT CAST(6 as tinyint);
DECLARE b tinyint DEFAULT CAST(7 as tinyint);
RETURN a * b;
END
SELECT meaning_of_life();
9 changes: 9 additions & 0 deletions tests/files/beginend_3.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
--Bigquery
CREATE OR REPLACE PROCEDURE mydataset.create_customer()
BEGIN
DECLARE id STRING;
SET id = GENERATE_UUID();
INSERT INTO mydataset.customers (customer_id)
VALUES(id);
SELECT FORMAT("Created customer %s", id);
END;
17 changes: 17 additions & 0 deletions tests/files/beginend_with_double_case_when_1.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
--Trino
WITH FUNCTION double_case(a bigint)
RETURNS varchar
BEGIN
CASE abs(a)
WHEN 0 THEN RETURN 'zero';
WHEN 1 THEN
CASE
WHEN a < 0 THEN RETURN 'minus one';
ELSE RETURN 'one';
END CASE;
ELSE RETURN 'other';
END CASE;
RETURN null;
END
SELECT double_case(0);
SELECT 1 ;
17 changes: 17 additions & 0 deletions tests/files/beginend_with_double_case_when_2.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
--Trino
WITH FUNCTION double_case(a bigint)
RETURNS varchar
BEGIN
RETURN
CASE abs(a)
WHEN 0 THEN 'zero'
WHEN 1 THEN
CASE
WHEN a < 0 THEN 'minus one'
ELSE 'one'
END
ELSE 'other'
END;
END
SELECT double_case(0);
SELECT 1 ;
21 changes: 21 additions & 0 deletions tests/files/beginend_with_double_case_when_3.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
--Trino
WITH FUNCTION double_case(a bigint)
RETURNS varchar
BEGIN
CASE
WHEN a IS NULL THEN RETURN 'null';
ELSE RETURN
CASE abs(a)
WHEN 0 THEN 'zero'
WHEN 1 THEN
CASE
WHEN a < 0 THEN 'minus one'
ELSE 'one'
END
ELSE 'other'
END;
END CASE;
RETURN null;
END
SELECT double_case(0);
SELECT 1 ;
4 changes: 0 additions & 4 deletions tests/files/begintag.sql

This file was deleted.

3 changes: 1 addition & 2 deletions tests/files/function.sql
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,8 @@ CREATE OR REPLACE FUNCTION foo(
, p_in2 INTEGER
) RETURNS INTEGER AS

DECLARE
v_foo INTEGER;
BEGIN
DECLARE v_foo INTEGER;
SELECT *
FROM foo
INTO v_foo;
Expand Down
4 changes: 2 additions & 2 deletions tests/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,8 +60,8 @@ def test_invalid_outfile(filepath, capsys):


def test_stdout(filepath, load_file, capsys):
path = filepath('begintag.sql')
expected = load_file('begintag.sql')
path = filepath('begincommit_1.sql')
expected = load_file('begincommit_1.sql')
sqlparse.cli.main([path])
out, _ = capsys.readouterr()
assert out == expected
Expand Down
58 changes: 52 additions & 6 deletions tests/test_split.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,7 @@ def test_split_backslash():
assert len(stmts) == 2


@pytest.mark.parametrize('fn', ['function.sql',
'function_psql.sql',
@pytest.mark.parametrize('fn', ['function_psql.sql',
'function_psql2.sql',
'function_psql3.sql',
'function_psql4.sql'])
Expand All @@ -33,6 +32,11 @@ def test_split_create_function(load_file, fn):
assert len(stmts) == 1
assert str(stmts[0]) == sql

# Check that the parser doesn't get in an incorrect state after the first split
stmts = sqlparse.parse(sql + sql)
assert len(stmts) == 2
assert str(stmts[0]) == sql


def test_split_dashcomments(load_file):
sql = load_file('dashcomment.sql')
Expand All @@ -50,20 +54,62 @@ def test_split_dashcomments_eol(s):
assert len(stmts) == 1


def test_split_begintag(load_file):
sql = load_file('begintag.sql')
def test_split_begincommit_1(load_file):
sql = load_file('begincommit_1.sql')
stmts = sqlparse.parse(sql)
assert len(stmts) == 3
assert ''.join(str(q) for q in stmts) == sql


def test_split_begintag_2(load_file):
sql = load_file('begintag_2.sql')
def test_split_begincommit_2(load_file):
sql = load_file('begincommit_2.sql')
stmts = sqlparse.parse(sql)
assert len(stmts) == 8
assert ''.join(str(q) for q in stmts) == sql


def test_split_beginend_1(load_file):
sql = load_file('beginend_1.sql')
stmts = sqlparse.parse(sql)
assert len(stmts) == 1
assert ''.join(str(q) for q in stmts) == sql


def test_split_beginend_2(load_file):
sql = load_file('beginend_2.sql')
stmts = sqlparse.parse(sql)
assert len(stmts) == 1
assert ''.join(str(q) for q in stmts) == sql


def test_split_beginend_3(load_file):
sql = load_file('beginend_3.sql')
stmts = sqlparse.parse(sql)
assert len(stmts) == 1
assert ''.join(str(q) for q in stmts) == sql


def test_split_beginend_with_double_case_when_1(load_file):
sql = load_file('beginend_with_double_case_when_1.sql')
stmts = sqlparse.parse(sql)
assert len(stmts) == 2
assert ''.join(str(q) for q in stmts) == sql


def test_split_beginend_with_double_case_when_2(load_file):
sql = load_file('beginend_with_double_case_when_2.sql')
stmts = sqlparse.parse(sql)
assert len(stmts) == 2
assert ''.join(str(q) for q in stmts) == sql


def test_split_beginend_with_double_case_when_3(load_file):
sql = load_file('beginend_with_double_case_when_3.sql')
stmts = sqlparse.parse(sql)
assert len(stmts) == 2
assert ''.join(str(q) for q in stmts) == sql


def test_split_dropif():
sql = 'DROP TABLE IF EXISTS FOO;\n\nSELECT * FROM BAR;'
stmts = sqlparse.parse(sql)
Expand Down