diff --git a/sqle/driver/mysql/executor/executor.go b/sqle/driver/mysql/executor/executor.go index 8f504fae30..841e9ca726 100644 --- a/sqle/driver/mysql/executor/executor.go +++ b/sqle/driver/mysql/executor/executor.go @@ -381,14 +381,28 @@ func (c *Executor) ShowDatabases(ignoreSysDatabase bool) ([]string, error) { +------------+------------------+ */ func (c *Executor) ShowSchemaTables(schema string) ([]string, error) { - query := fmt.Sprintf( - "select TABLE_NAME from information_schema.tables where table_schema='%s' and TABLE_TYPE in ('BASE TABLE','SYSTEM VIEW')", schema) + return c.showSchemaObjects(schema, "'BASE TABLE'", "'SYSTEM VIEW'") +} + +func (c *Executor) ShowAllSchemaObjects(schema string) ([]string, error) { + return c.showSchemaObjects(schema) +} +func (c *Executor) showSchemaObjects(schema string, objects ...string) ([]string, error) { + var query string if c.IsLowerCaseTableNames() { schema = strings.ToLower(schema) - query = fmt.Sprintf( - "select TABLE_NAME from information_schema.tables where lower(table_schema)='%s' and TABLE_TYPE in ('BASE TABLE','SYSTEM VIEW')", schema) - + if len(objects) == 0 { + query = fmt.Sprintf("select TABLE_NAME from information_schema.tables where lower(table_schema)='%s'", schema) + } else { + query = fmt.Sprintf("select TABLE_NAME from information_schema.tables where lower(table_schema)='%s' and TABLE_TYPE in (%s)", schema, strings.Join(objects, ",")) + } + } else { + if len(objects) == 0 { + query = fmt.Sprintf("select TABLE_NAME from information_schema.tables where table_schema='%s'", schema) + } else { + query = fmt.Sprintf("select TABLE_NAME from information_schema.tables where table_schema='%s' and TABLE_TYPE in (%s)", schema, strings.Join(objects, ",")) + } } result, err := c.Db.Query(query) if err != nil { diff --git a/sqle/driver/mysql/session/context.go b/sqle/driver/mysql/session/context.go index 4cbc502561..ef7687e310 100644 --- a/sqle/driver/mysql/session/context.go +++ b/sqle/driver/mysql/session/context.go @@ -465,7 +465,7 @@ func (c *Context) IsTableExist(stmt *ast.TableName) (bool, error) { return false, nil } - tables, err := c.e.ShowSchemaTables(schemaName) + tables, err := c.e.ShowAllSchemaObjects(schemaName) if err != nil { return false, err }