Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions sqle/driver/mysql/audit.go
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ func (i *MysqlDriverImpl) checkInvalidCreateTable(stmt *ast.CreateTableStmt) err
if !schemaExist {
i.result.Add(driverV2.RuleLevelError, "", plocale.Bundle.LocalizeAll(plocale.SchemaNotExistMessage), schemaName)
} else {
tableExist, err := i.Ctx.IsTableExist(stmt.Table)
tableExist, err := i.Ctx.IsTableOrViewExist(stmt.Table)
if err != nil {
return err
}
Expand All @@ -120,7 +120,7 @@ func (i *MysqlDriverImpl) checkInvalidCreateTable(stmt *ast.CreateTableStmt) err
i.getTableName(stmt.Table))
}
if stmt.ReferTable != nil {
referTableExist, err := i.Ctx.IsTableExist(stmt.ReferTable)
referTableExist, err := i.Ctx.IsTableOrViewExist(stmt.ReferTable)
if err != nil {
return err
}
Expand Down Expand Up @@ -531,7 +531,7 @@ func (i *MysqlDriverImpl) checkInvalidDropTable(stmt *ast.DropTableStmt) error {
if !schemaExist {
needExistsSchemasName = append(needExistsSchemasName, schemaName)
} else {
tableExist, err := i.Ctx.IsTableExist(table)
tableExist, err := i.Ctx.IsTableOrViewExist(table)
if err != nil {
return err
}
Expand Down Expand Up @@ -896,7 +896,7 @@ func (i *MysqlDriverImpl) checkInvalidUpdate(stmt *ast.UpdateStmt) error {
if !schemaExist {
needExistsSchemasName = append(needExistsSchemasName, schemaName)
} else {
tableExist, err := i.Ctx.IsTableExist(table)
tableExist, err := i.Ctx.IsTableOrViewExist(table)
if err != nil {
return err
}
Expand Down Expand Up @@ -1024,7 +1024,7 @@ func (i *MysqlDriverImpl) checkInvalidDelete(stmt *ast.DeleteStmt) error {
if !schemaExist {
needExistsSchemasName = append(needExistsSchemasName, schemaName)
} else {
tableExist, err := i.Ctx.IsTableExist(table)
tableExist, err := i.Ctx.IsTableOrViewExist(table)
if err != nil {
return err
}
Expand Down Expand Up @@ -1135,7 +1135,7 @@ func (i *MysqlDriverImpl) checkInvalidSelect(stmt *ast.SelectStmt) error {
if !schemaExist {
needExistsSchemasName = append(needExistsSchemasName, schemaName)
} else {
tableExist, err := i.Ctx.IsTableExist(table)
tableExist, err := i.Ctx.IsTableOrViewExist(table)
if err != nil {
return err
}
Expand Down
25 changes: 24 additions & 1 deletion sqle/driver/mysql/audit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,22 @@ func DefaultMysqlInspect() *MysqlDriverImpl {
}
}

// setupMockViewQueryExpectations sets up mock expectations for view queries to avoid test failures
func setupMockViewQueryExpectations(handler sqlmock.Sqlmock) {
handler.MatchExpectationsInOrder(false)
// Set up expectations that can be matched multiple times
// Use regexp to match both exact and lowercase schema queries
handler.ExpectQuery(regexp.QuoteMeta(`select TABLE_NAME from information_schema.tables where table_schema='exist_db' and TABLE_TYPE='VIEW'`)).
WillReturnRows(sqlmock.NewRows([]string{"TABLE_NAME"}))
handler.ExpectQuery(regexp.QuoteMeta(`select TABLE_NAME from information_schema.tables where lower(table_schema)='exist_db' and TABLE_TYPE='VIEW'`)).
WillReturnRows(sqlmock.NewRows([]string{"TABLE_NAME"}))
// Add additional expectations for potential multiple calls
handler.ExpectQuery(regexp.QuoteMeta(`select TABLE_NAME from information_schema.tables where table_schema='exist_db' and TABLE_TYPE='VIEW'`)).
WillReturnRows(sqlmock.NewRows([]string{"TABLE_NAME"}))
handler.ExpectQuery(regexp.QuoteMeta(`select TABLE_NAME from information_schema.tables where lower(table_schema)='exist_db' and TABLE_TYPE='VIEW'`)).
WillReturnRows(sqlmock.NewRows([]string{"TABLE_NAME"}))
}

func NewMockInspect(e *executor.Executor) *MysqlDriverImpl {
log.Logger().SetLevel(logrus.ErrorLevel)
return &MysqlDriverImpl{
Expand Down Expand Up @@ -150,6 +166,9 @@ func AIMockExecutor(expectations []*AIMockSQLExpectation) (*executor.Executor, e
}
handler.MatchExpectationsInOrder(false)

// 为视图查询添加 mock 期望
setupMockViewQueryExpectations(handler)

for _, exp := range expectations {
handler.ExpectQuery(regexp.QuoteMeta(exp.Query)).
WillReturnRows(exp.Rows)
Expand Down Expand Up @@ -5485,6 +5504,8 @@ func TestDDLRecommendTableColumnCharsetSame(t *testing.T) {
// 需要查询数据库 获取数据库默认字符集
e, handler, err := executor.NewMockExecutor()
assert.NoError(t, err)
// 为视图查询添加 mock 期望
setupMockViewQueryExpectations(handler)
inspect1 := NewMockInspect(e)

// 不触发规则
Expand Down Expand Up @@ -7961,8 +7982,10 @@ func TestDDLCheckCharLength(t *testing.T) {
rule := rulepkg.RuleHandlerMap[rulepkg.DDLCheckCharLength].Rule
for _, arg := range args {
rule.Params.SetParamValue(rulepkg.DefaultSingleParamKeyName, arg.Param)
e, _, err := executor.NewMockExecutor()
e, handler, err := executor.NewMockExecutor()
assert.NoError(t, err)
// 为视图查询添加 mock 期望
setupMockViewQueryExpectations(handler)
inspect := NewMockInspect(e)

t.Run(arg.Name, func(t *testing.T) {
Expand Down
144 changes: 144 additions & 0 deletions sqle/driver/mysql/session/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,11 @@ type TableInfo struct {
Selectivity map[string] /*column name or index name*/ float64 /*selectivity*/
}

type ViewInfo struct {
// isLoad indicate whether ViewInfo load from database or not.
isLoad bool
}

type SchemaInfo struct {
DefaultEngine string
engineLoad bool
Expand All @@ -58,6 +63,7 @@ type SchemaInfo struct {
collationLoad bool
IsRealSchema bool // issue #1832, 判断当前的 schema 是否真实存在于数据库中.
Tables map[string]*TableInfo
Views map[string]*ViewInfo
}

type HistorySQLInfo struct {
Expand Down Expand Up @@ -118,6 +124,7 @@ func NewContext(parent *Context, opts ...contextOption) *Context {
for schemaName, schema := range parent.schemas {
newSchema := &SchemaInfo{
Tables: map[string]*TableInfo{},
Views: map[string]*ViewInfo{},
}
if schema == nil || schema.Tables == nil {
continue
Expand All @@ -132,6 +139,13 @@ func NewContext(parent *Context, opts ...contextOption) *Context {
AlterTables: table.AlterTables,
}
}
if schema.Views != nil {
for viewName, view := range schema.Views {
newSchema.Views[viewName] = &ViewInfo{
isLoad: view.isLoad,
}
}
}
ctx.schemas[schemaName] = newSchema
}

Expand Down Expand Up @@ -216,6 +230,7 @@ func (c *Context) addSchema(name string) {
}
c.schemas[name] = &SchemaInfo{
Tables: map[string]*TableInfo{},
Views: map[string]*ViewInfo{},
}
}

Expand Down Expand Up @@ -303,6 +318,54 @@ func (c *Context) delTable(schemaName, tableName string) {
delete(schema.Tables, tableName)
}

func (c *Context) hasLoadViews(schemaName string) (hasLoad bool) {
if schema, ok := c.getSchema(schemaName); ok {
if schema.Views == nil {
hasLoad = false
} else {
hasLoad = true
}
}
return
}

func (c *Context) loadViews(schemaName string, viewsName []string) {
schema, ok := c.getSchema(schemaName)
if !ok {
return
}
if c.hasLoadViews(schemaName) {
return
}
if schema.Views == nil {
schema.Views = map[string]*ViewInfo{}
}
isLowerCaseTableName := c.IsLowerCaseTableName()
for _, name := range viewsName {
if isLowerCaseTableName {
name = strings.ToLower(name)
}
schema.Views[name] = &ViewInfo{
isLoad: true,
}
}
}

func (c *Context) hasView(schemaName, viewName string) (has bool) {
schema, SchemaExist := c.getSchema(schemaName)
if !SchemaExist {
return false
}
if !c.hasLoadViews(schemaName) {
return false
}
if c.IsLowerCaseTableName() {
viewName = strings.ToLower(viewName)
}
_, has = schema.Views[viewName]
return
}

func (c *Context) SetCurrentSchema(schema string) {
if c.IsLowerCaseTableName() {
schema = strings.ToLower(schema)
Expand Down Expand Up @@ -493,6 +556,87 @@ func (c *Context) IsTableExist(stmt *ast.TableName) (bool, error) {
return c.hasTable(schemaName, stmt.Name.String()), nil
}

// IsTableOrViewExist check table or view is exist or not.
func (c *Context) IsTableOrViewExist(stmt *ast.TableName) (bool, error) {
Comment thread
iwanghc marked this conversation as resolved.
schemaName := c.GetSchemaName(stmt)
schemaExist, err := c.IsSchemaExist(schemaName)
if err != nil {
return schemaExist, err
}
if !schemaExist {
return false, nil
}

if !c.hasLoadTables(schemaName) {
if c.e == nil {
return false, nil
}

tables, err := c.e.ShowSchemaTables(schemaName)
if err != nil {
return false, err
}
c.loadTables(schemaName, tables)
}

lowerCaseTableNames, err := c.GetSystemVariable(SysVarLowerCaseTableNames)
if err != nil {
return false, err
}

var tableExist bool
if lowerCaseTableNames != "0" {
Comment thread
iwanghc marked this conversation as resolved.
capitalizedTable := make(map[string]struct{})
schemaInfo, ok := c.getSchema(schemaName)
if !ok {
return false, fmt.Errorf("schema %s not exist", schemaName)
}

for name := range schemaInfo.Tables {
capitalizedTable[strings.ToUpper(name)] = struct{}{}
}
_, tableExist = capitalizedTable[strings.ToUpper(stmt.Name.String())]
} else {
tableExist = c.hasTable(schemaName, stmt.Name.String())
}

// If table exists, return true
if tableExist {
return true, nil
}

// If table doesn't exist, check if it's a view
if !c.hasLoadViews(schemaName) {
if c.e == nil {
Comment thread
iwanghc marked this conversation as resolved.
return false, nil
}

views, err := c.e.ShowSchemaViews(schemaName)
if err != nil {
return false, err
}
c.loadViews(schemaName, views)
}

var viewExist bool
if lowerCaseTableNames != "0" {
capitalizedView := make(map[string]struct{})
schemaInfo, ok := c.getSchema(schemaName)
if !ok {
return false, fmt.Errorf("schema %s not exist", schemaName)
}

for name := range schemaInfo.Views {
capitalizedView[strings.ToUpper(name)] = struct{}{}
}
_, viewExist = capitalizedView[strings.ToUpper(stmt.Name.String())]
} else {
viewExist = c.hasView(schemaName, stmt.Name.String())
}

return viewExist, nil
}

const (
SysVarLowerCaseTableNames = "lower_case_table_names"
)
Expand Down
Loading