diff --git a/migrator.go b/migrator.go index 1ef7940..dbfb014 100644 --- a/migrator.go +++ b/migrator.go @@ -396,7 +396,8 @@ WHERE TABLE_CATALOG = ? AND TABLE_NAME = ?`) { _, schemaName, tableName := splitFullQualifiedName(stmt.Table) - query := "SELECT c.COLUMN_NAME, t.CONSTRAINT_TYPE FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS t JOIN INFORMATION_SCHEMA.CONSTRAINT_COLUMN_USAGE c ON c.CONSTRAINT_NAME=t.CONSTRAINT_NAME WHERE t.CONSTRAINT_TYPE IN ('PRIMARY KEY', 'UNIQUE') AND c.TABLE_CATALOG = ? AND c.TABLE_NAME = ?" + + query := "SELECT t.CONSTRAINT_NAME FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS t JOIN INFORMATION_SCHEMA.CONSTRAINT_COLUMN_USAGE c ON c.CONSTRAINT_NAME=t.CONSTRAINT_NAME WHERE t.CONSTRAINT_TYPE IN ('PRIMARY KEY', 'UNIQUE') AND c.TABLE_CATALOG = ? AND c.TABLE_NAME = ? AND CONSTRAINT_TYPE = ?" queryParameters := []interface{}{m.CurrentDatabase(), tableName} @@ -404,15 +405,36 @@ WHERE TABLE_CATALOG = ? AND TABLE_NAME = ?`) query += " AND c.TABLE_SCHEMA = ?" queryParameters = append(queryParameters, schemaName) } - + queryParameters = append(queryParameters, "UNIQUE") columnTypeRows, err := m.DB.Raw(query, queryParameters...).Rows() if err != nil { return err } + uniqueContraints := map[string]int{} + for columnTypeRows.Next() { + var constraintName string + columnTypeRows.Scan(&constraintName) + uniqueContraints[constraintName]++ + } + _ = columnTypeRows.Close() + + query = "SELECT c.COLUMN_NAME, t.CONSTRAINT_NAME, t.CONSTRAINT_TYPE FROM INFORMATION_SCHEMA.TABLE_CONSTRAINTS t JOIN INFORMATION_SCHEMA.CONSTRAINT_COLUMN_USAGE c ON c.CONSTRAINT_NAME=t.CONSTRAINT_NAME WHERE t.CONSTRAINT_TYPE IN ('PRIMARY KEY', 'UNIQUE') AND c.TABLE_CATALOG = ? AND c.TABLE_NAME = ?" + + queryParameters = []interface{}{m.CurrentDatabase(), tableName} + + if schemaName != "" { + query += " AND c.TABLE_SCHEMA = ?" + queryParameters = append(queryParameters, schemaName) + } + + columnTypeRows, err = m.DB.Raw(query, queryParameters...).Rows() + if err != nil { + return err + } for columnTypeRows.Next() { - var name, columnType string - _ = columnTypeRows.Scan(&name, &columnType) + var name, constraintName, columnType string + _ = columnTypeRows.Scan(&name, &constraintName, &columnType) for idx, c := range columnTypes { mc := c.(migrator.ColumnType) if mc.NameValue.String == name { @@ -420,7 +442,9 @@ WHERE TABLE_CATALOG = ? AND TABLE_NAME = ?`) case "PRIMARY KEY": mc.PrimaryKeyValue = sql.NullBool{Bool: true, Valid: true} case "UNIQUE": - mc.UniqueValue = sql.NullBool{Bool: true, Valid: true} + if uniqueContraints[constraintName] == 1 { + mc.UniqueValue = sql.NullBool{Bool: true, Valid: true} + } } columnTypes[idx] = mc break @@ -573,17 +597,29 @@ func (m Migrator) HasConstraint(value interface{}, name string) bool { return m.DB.Raw( `SELECT count(*) FROM ( + -- Check CHECK constraint SELECT C.name, T.name as table_name FROM sys.check_constraints as C INNER JOIN sys.tables as T on C.parent_object_id=T.object_id INNER JOIN INFORMATION_SCHEMA.TABLES as I on I.TABLE_NAME = T.name WHERE C.name = ? AND I.TABLE_NAME = ? AND I.TABLE_SCHEMA like ? AND I.TABLE_CATALOG = ? UNION + -- Check foreign key constraints SELECT FK.name, T.name as table_name FROM sys.foreign_keys as FK INNER JOIN sys.tables as T on FK.parent_object_id=T.object_id INNER JOIN INFORMATION_SCHEMA.TABLES as I on I.TABLE_NAME = T.name WHERE FK.name = ? AND I.TABLE_NAME = ? AND I.TABLE_SCHEMA like ? AND I.TABLE_CATALOG = ? + UNION + -- Check Unique Constraint + SELECT UK.name, T.name as table_name FROM sys.key_constraints as UK + INNER JOIN sys.tables as T on UK.parent_object_id = T.object_id + INNER JOIN INFORMATION_SCHEMA.TABLES as I on I.TABLE_NAME = T.name + WHERE UK.type = 'UQ' AND UK.name = ? AND I.TABLE_NAME = ? AND I.TABLE_SCHEMA like ? AND I.TABLE_CATALOG = ? ) as constraints;`, + // CHECK constraint parameter + name, tableName, tableSchema, tableCatalog, + // Foreign key constraint parameters name, tableName, tableSchema, tableCatalog, + // Unique Constraint Parameter name, tableName, tableSchema, tableCatalog, ).Row().Scan(&count) }) diff --git a/migrator_test.go b/migrator_test.go index 0bab9d9..76edd38 100644 --- a/migrator_test.go +++ b/migrator_test.go @@ -208,6 +208,36 @@ type TestTableFieldCommentUpdate struct { func (*TestTableFieldCommentUpdate) TableName() string { return "test_table_field_comment" } +type TestTableFieldUnique struct { + ID string `gorm:"column:id;primaryKey;comment:"` // field comment is an empty string + Name string `gorm:"column:name;unique;comment:姓名"` + Age uint `gorm:"column:age;comment:年龄"` +} + +func (*TestTableFieldUnique) TableName() string { return "test_table_field_unique" } + +func TestHasConstraint(t *testing.T) { + db, err := gorm.Open(sqlserver.Open(sqlserverDSN)) + if err != nil { + t.Fatal(err) + } + dm := db.Debug().Migrator() + tableModel := new(TestTableFieldUnique) + defer func() { + if err = dm.DropTable(tableModel); err != nil { + t.Errorf("couldn't drop table %q, got error: %v", tableModel.TableName(), err) + } + }() + + if err = dm.AutoMigrate(tableModel); err != nil { + t.Fatal(err) + } + hasName := dm.HasConstraint(tableModel, "name") + if !hasName { + t.Fatalf("expected unique") + } +} + func TestMigrator_MigrateColumnComment(t *testing.T) { db, err := gorm.Open(sqlserver.Open(sqlserverDSN)) if err != nil {