diff --git a/sqlglot/expressions.py b/sqlglot/expressions.py index 7f3a4a8e3a..b98fd604d2 100644 --- a/sqlglot/expressions.py +++ b/sqlglot/expressions.py @@ -2218,7 +2218,7 @@ class ColumnPrefix(Expression): class PrimaryKey(Expression): - arg_types = {"expressions": True, "options": False} + arg_types = {"expressions": True, "options": False, "include": False} # https://www.postgresql.org/docs/9.1/sql-selectinto.html diff --git a/sqlglot/generator.py b/sqlglot/generator.py index 4431ece962..5ffd3ab6ba 100644 --- a/sqlglot/generator.py +++ b/sqlglot/generator.py @@ -2995,11 +2995,13 @@ def foreignkey_sql(self, expression: exp.ForeignKey) -> str: options = f" {options}" if options else "" return f"FOREIGN KEY{expressions}{reference}{delete}{update}{options}" - def primarykey_sql(self, expression: exp.ForeignKey) -> str: + def primarykey_sql(self, expression: exp.PrimaryKey) -> str: expressions = self.expressions(expression, flat=True) + include = self.expressions(expression, key="include", flat=True) + include = f" INCLUDE ({include})" if include else "" options = self.expressions(expression, key="options", flat=True, sep=" ") options = f" {options}" if options else "" - return f"PRIMARY KEY ({expressions}){options}" + return f"PRIMARY KEY ({expressions}){include}{options}" def if_sql(self, expression: exp.If) -> str: return self.case_sql(exp.Case(ifs=[expression], default=expression.args.get("false"))) diff --git a/sqlglot/parser.py b/sqlglot/parser.py index 910461a6e8..8f928554e8 100644 --- a/sqlglot/parser.py +++ b/sqlglot/parser.py @@ -6296,8 +6296,16 @@ def _parse_primary_key( expressions = self._parse_wrapped_csv( self._parse_primary_key_part, optional=wrapped_optional ) + + index_params = None + if self._match_text_seq("INCLUDE", advance=False): + index_params = self._parse_index_params() + include = index_params.args.get("include") if index_params else None + options = self._parse_key_constraint_options() - return self.expression(exp.PrimaryKey, expressions=expressions, options=options) + return self.expression( + exp.PrimaryKey, expressions=expressions, include=include, options=options + ) def _parse_bracket_key_value(self, is_map: bool = False) -> t.Optional[exp.Expression]: return self._parse_slice(self._parse_alias(self._parse_assignment(), explicit=True)) diff --git a/tests/dialects/test_postgres.py b/tests/dialects/test_postgres.py index 3eb497419d..f1e70b1d31 100644 --- a/tests/dialects/test_postgres.py +++ b/tests/dialects/test_postgres.py @@ -987,6 +987,7 @@ def test_ddl(self): self.validate_identity( "CREATE TABLE t (vid INT NOT NULL, CONSTRAINT ht_vid_nid_fid_idx EXCLUDE (INT4RANGE(vid, nid) WITH &&, INT4RANGE(fid, fid, '[]') WITH &&))" ) + self.validate_identity("CREATE TABLE t (i INT, a TEXT, PRIMARY KEY (i) INCLUDE (a))") self.validate_identity( "CREATE TABLE t (i INT, PRIMARY KEY (i), EXCLUDE USING gist(col varchar_pattern_ops DESC NULLS LAST WITH &&) WITH (sp1=1, sp2=2))" )