Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
172 changes: 85 additions & 87 deletions sqle/pkg/postgresql/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -104,110 +104,108 @@ func (o *DB) ShowSchemaViews(schema string) ([]string, error) {
return getResultSqls(o.Db, query)
}

func (o *DB) ShowCreateTables(database, tableName string, schemas []string) ([]string, error) {
func (o *DB) ShowCreateTables(database, schema, tableName string) ([]string, error) {
tables := make([]string, 0)
for _, schema := range schemas {
tableDDl := fmt.Sprintf("CREATE TABLE %s.%s(", schema, tableName)
if o.IsCaseSensitive {
database = strings.ToLower(database)
schema = strings.ToLower(schema)
tableName = strings.ToLower(tableName)
}
columnsCondition := fmt.Sprintf("table_catalog = '%s' AND table_schema = '%s' AND table_name = '%s'",
database, schema, tableName)
if o.IsCaseSensitive {
columnsCondition = fmt.Sprintf("lower(table_catalog) = '%s' AND lower(table_schema) = '%s' "+
"AND lower(table_name) = '%s'", database, schema, tableName)
}
// 获取列定义,多个英文逗号分割
columns := fmt.Sprintf("SELECT string_agg(column_name || ' ' || "+
"CASE "+
" WHEN data_type IN ('character', 'character varying', 'text') "+
" THEN data_type || '(' || character_maximum_length || ')' "+
" WHEN data_type IN ('numeric', 'decimal') "+
" THEN data_type || '(' || numeric_precision || ',' || numeric_scale || ')' "+
" WHEN data_type IN ('integer', 'smallint', 'bigint') THEN data_type "+
" ELSE data_type "+
" END "+
" || "+
" CASE "+
" WHEN column_default != '' THEN ' DEFAULT ' || column_default ELSE '' END "+
" || "+
" CASE "+
" WHEN is_nullable = 'NO' THEN ' NOT NULL' ELSE '' END, ',\n ' ORDER BY ordinal_position) AS columns_sql"+
" FROM information_schema.columns "+
" WHERE %s GROUP BY table_name", columnsCondition)
sqls, err := getResultSqls(o.Db, columns)
if err != nil {
log.Printf("search column definition error:%s\n", err)
return nil, err
}
if len(sqls) == 0 {
tableDDl := fmt.Sprintf("CREATE TABLE %s.%s(", schema, tableName)
if o.IsCaseSensitive {
database = strings.ToLower(database)
schema = strings.ToLower(schema)
tableName = strings.ToLower(tableName)
}
columnsCondition := fmt.Sprintf("table_catalog = '%s' AND table_schema = '%s' AND table_name = '%s'",
database, schema, tableName)
if o.IsCaseSensitive {
columnsCondition = fmt.Sprintf("lower(table_catalog) = '%s' AND lower(table_schema) = '%s' "+
"AND lower(table_name) = '%s'", database, schema, tableName)
}
// 获取列定义,多个英文逗号分割
columns := fmt.Sprintf("SELECT string_agg(column_name || ' ' || "+
"CASE "+
" WHEN data_type IN ('char', 'varchar', 'character', 'character varying', 'text') "+
" THEN data_type || '(' || COALESCE(character_maximum_length, 0) || ')' "+
" WHEN data_type IN ('numeric', 'decimal') "+
" THEN data_type || '(' || COALESCE(numeric_precision, 0) || ',' || COALESCE(numeric_scale, 0) || ')' "+
" WHEN data_type IN ('integer', 'smallint', 'bigint') THEN data_type "+
" ELSE data_type "+
" END "+
" || "+
" CASE "+
" WHEN column_default != '' THEN ' DEFAULT ' || column_default ELSE '' END "+
" || "+
" CASE "+
" WHEN is_nullable = 'NO' THEN ' NOT NULL' ELSE '' END, ',\n ' ORDER BY ordinal_position) AS columns_sql"+
" FROM information_schema.columns "+
" WHERE %s GROUP BY table_name", columnsCondition)
sqls, err := getResultSqls(o.Db, columns)
if err != nil {
log.Printf("search column definition error:%s\n", err)
return nil, err
}
if len(sqls) == 0 {
return tables, nil
}
tableDDl += strings.Join(sqls, "")
constraintsCondition := fmt.Sprintf("n.nspname = '%s' AND C.relname = '%s'", schema, tableName)
if o.IsCaseSensitive {
constraintsCondition = fmt.Sprintf("lower(n.nspname) = '%s' "+
"AND lower(C.relname) = '%s'", schema, tableName)
}
// 获取所有约束
constraints := fmt.Sprintf("SELECT 'CONSTRAINT ' || r.conname || ' ' || "+
" pg_catalog.pg_get_constraintdef ( r.OID, TRUE ) AS constraint_definition "+
" FROM pg_catalog.pg_constraint r "+
" JOIN pg_catalog.pg_class C ON C.OID = r.conrelid "+
" JOIN pg_catalog.pg_namespace n ON n.OID = C.relnamespace "+
" WHERE %s", constraintsCondition)
sqls, err = getResultSqls(o.Db, constraints)
if err != nil {
log.Printf("search constraint definition error:%s\n", err)
return nil, err
}
for _, sqlContext := range sqls {
tableDDl += ",\n" + sqlContext
}
tableDDl += ")"
indexesCondition := fmt.Sprintf("schemaname = '%s' and tablename = '%s' ", schema, tableName)
if o.IsCaseSensitive {
indexesCondition = fmt.Sprintf("lower(schemaname) = '%s' and lower(tablename) = '%s'",
schema, tableName)
}
// 获取索引
indexes := fmt.Sprintf("SELECT indexdef AS index_definition FROM pg_indexes "+
" WHERE %s", indexesCondition)
sqls, err = getResultSqls(o.Db, indexes)
if err != nil {
log.Printf("search index definition error:%s\n", err)
return nil, err
}
for _, sqlContent := range sqls {
if strings.Contains(sqlContent, "CREATE UNIQUE INDEX") {
continue
}
tableDDl += strings.Join(sqls, "")
constraintsCondition := fmt.Sprintf("d.datname = '%s' AND n.nspname = '%s' AND C.relname = '%s'",
database, schema, tableName)
if o.IsCaseSensitive {
constraintsCondition = fmt.Sprintf("lower(d.datname) = '%s' AND lower(n.nspname) = '%s' "+
"AND lower(C.relname) = '%s'", database, schema, tableName)
}
// 获取所有约束
constraints := fmt.Sprintf("SELECT 'CONSTRAINT ' || r.conname || ' ' || "+
" pg_catalog.pg_get_constraintdef ( r.OID, TRUE ) AS constraint_definition "+
" FROM pg_catalog.pg_constraint r "+
" JOIN pg_catalog.pg_class C ON C.OID = r.conrelid "+
" JOIN pg_catalog.pg_namespace n ON n.OID = C.relnamespace "+
" JOIN pg_catalog.pg_database d ON d.datname = n.nspname "+
" WHERE %s", constraintsCondition)
sqls, err = getResultSqls(o.Db, constraints)
if err != nil {
log.Printf("search constraint definition error:%s\n", err)
return nil, err
}
for _, sqlContext := range sqls {
tableDDl += ",\n" + sqlContext
}
tableDDl += ")"
indexesCondition := fmt.Sprintf("schemaname = '%s' and tablename = '%s' ", schema, tableName)
if o.IsCaseSensitive {
indexesCondition = fmt.Sprintf("lower(schemaname) = '%s' and lower(tablename) = '%s'",
schema, tableName)
}
// 获取索引
indexes := fmt.Sprintf("SELECT indexdef AS index_definition FROM pg_indexes "+
" WHERE %s", indexesCondition)
sqls, err = getResultSqls(o.Db, indexes)
if err != nil {
log.Printf("search index definition error:%s\n", err)
return nil, err
}
for _, sqlContent := range sqls {
if strings.Contains(sqlContent, "CREATE UNIQUE INDEX") {
continue
}
tableDDl += ";\n" + sqlContent
}
tables = append(tables, tableDDl)
tableDDl += ";\n" + sqlContent
}
tables = append(tables, tableDDl)
return tables, nil
}

func (o *DB) ShowCreateViews(database, tableName string) ([]string, error) {
func (o *DB) ShowCreateViews(database, schema, tableName string) ([]string, error) {
query := fmt.Sprintf(
"SELECT 'CREATE OR REPLACE VIEW ' || table_schema || '.' || table_name || ' AS ' || view_definition"+
" AS create_view_statement "+
" FROM information_schema.views WHERE table_catalog = '%s' AND table_name = '%s'",
database, tableName)
" FROM information_schema.views "+
" WHERE table_catalog = '%s' AND table_schema = '%s' AND table_name = '%s'",
database, schema, tableName)

if o.IsCaseSensitive {
database = strings.ToLower(database)
tableName = strings.ToLower(tableName)
query = fmt.Sprintf(
"SELECT 'CREATE OR REPLACE VIEW ' || table_schema || '.' || table_name || ' AS ' || view_definition"+
" AS create_view_statement "+
" FROM information_schema.views WHERE lower(table_catalog) = '%s' AND lower(table_name) = '%s'",
database, tableName)
" FROM information_schema.views "+
" WHERE lower(table_catalog) = '%s' AND lower(table_schema) = '%s' AND lower(table_name) = '%s'",
database, schema, tableName)
}
return getResultSqls(o.Db, query)
}
Expand Down
90 changes: 61 additions & 29 deletions sqle/server/auditplan/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -1595,20 +1595,6 @@ func (at *PostgreSQLSchemaMetaTask) collectorDo() {
defer db.Close()
db.IsCaseSensitive = db.GetCaseSensitive()

tables, err := db.ShowSchemaTables(at.ap.InstanceDatabase)
if err != nil {
at.logger.Errorf("get schema table fail, error: %s", err)
return
}
var views []string
if at.ap.Params.GetParam("collect_view").Bool() {
views, err = db.ShowSchemaViews(at.ap.InstanceDatabase)
if err != nil {
at.logger.Errorf("get schema view fail, error: %s", err)
return
}
}

schemas, err := db.GetAllUserSchemas()
if err != nil {
at.logger.Errorf("get database=%s schemas error: %s", at.ap.InstanceDatabase, err)
Expand All @@ -1619,23 +1605,69 @@ func (at *PostgreSQLSchemaMetaTask) collectorDo() {
return
}

sqls := make([]string, 0, len(tables)+len(views))
for _, table := range tables {
tableSqls, err := db.ShowCreateTables(at.ap.InstanceDatabase, table, schemas)
if err != nil {
at.logger.Errorf("show create table fail, error: %s", err)
return
}
sqls = append(sqls, tableSqls...)
wg := sync.WaitGroup{}
wg.Add(len(schemas) * 2)
tableMutex := sync.Mutex{}
viewMutex := sync.Mutex{}
sqls := make([]string, 0)
finalTableSqls := make([]string, 0)
finalViewSqls := make([]string, 0)
for _, schema := range schemas {
go func(schema string) {
defer wg.Done()
tables, err := db.ShowSchemaTables(schema)
if err != nil {
at.logger.Errorf("get schema table fail, error: %s", err)
return
}
for _, table := range tables {
tableSqls, err := db.ShowCreateTables(at.ap.InstanceDatabase, schema, table)
if err != nil {
at.logger.Errorf("show create table fail, error: %s", err)
return
}
tableMutex.Lock()
if len(tableSqls) > 0 {
finalTableSqls = append(finalTableSqls, tableSqls...)
}
tableMutex.Unlock()
}
}(schema)

go func(schema string) {
defer wg.Done()
var views []string
if at.ap.Params.GetParam("collect_view").Bool() {
views, err = db.ShowSchemaViews(schema)
if err != nil {
at.logger.Errorf("get schema view fail, error: %s", err)
return
}
}
for _, view := range views {
viewSqls, err := db.ShowCreateViews(at.ap.InstanceDatabase, schema, view)
if err != nil {
at.logger.Errorf("show create view fail, error: %s", err)
return
}
viewMutex.Lock()
if len(viewSqls) > 0 {
finalViewSqls = append(finalViewSqls, viewSqls...)
}
viewMutex.Unlock()
}
}(schema)
}
for _, view := range views {
viewSqls, err := db.ShowCreateViews(at.ap.InstanceDatabase, view)
if err != nil {
at.logger.Errorf("show create view fail, error: %s", err)
return
}
sqls = append(sqls, viewSqls...)
wg.Wait()

if len(finalTableSqls) > 0 {
sqls = append(sqls, finalTableSqls...)
}

if len(finalViewSqls) > 0 {
sqls = append(sqls, finalViewSqls...)
}

if len(sqls) > 0 {
err = at.persist.OverrideAuditPlanSQLs(at.ap.ID, convertRawSQLToModelSQLs(sqls))
if err != nil {
Expand Down