diff --git a/dumpling/export/BUILD.bazel b/dumpling/export/BUILD.bazel index 693064384315f..a86111a9cbe1f 100644 --- a/dumpling/export/BUILD.bazel +++ b/dumpling/export/BUILD.bazel @@ -21,6 +21,7 @@ go_library( "task.go", "util.go", "writer.go", + "writer_parquet.go", "writer_util.go", ], importpath = "github.com/pingcap/tidb/dumpling/export", @@ -65,6 +66,13 @@ go_library( "@com_github_tikv_pd_client//:client", "@com_github_tikv_pd_client//http", "@com_github_tikv_pd_client//pkg/caller", + "@com_github_xitongsys_parquet_go//layout", + "@com_github_xitongsys_parquet_go//marshal", + "@com_github_xitongsys_parquet_go//parquet", + "@com_github_xitongsys_parquet_go//schema", + "@com_github_xitongsys_parquet_go//types", + "@com_github_xitongsys_parquet_go//writer", + "@com_github_xitongsys_parquet_go_source//buffer", "@io_etcd_go_etcd_client_v3//:client", "@org_golang_x_sync//errgroup", "@org_uber_go_atomic//:atomic", @@ -91,6 +99,7 @@ go_test( "status_test.go", "util_for_test.go", "util_test.go", + "writer_parquet_test.go", "writer_serial_test.go", "writer_test.go", ], @@ -118,6 +127,9 @@ go_test( "@com_github_prometheus_client_golang//prometheus/collectors", "@com_github_spf13_pflag//:pflag", "@com_github_stretchr_testify//require", + "@com_github_xitongsys_parquet_go//reader", + "@com_github_xitongsys_parquet_go//types", + "@com_github_xitongsys_parquet_go_source//local", "@com_github_tikv_pd_client//:client", "@com_github_tikv_pd_client//clients/gc", "@org_golang_x_sync//errgroup", diff --git a/dumpling/export/config.go b/dumpling/export/config.go index 26f20ca42fef1..32354a79041c4 100644 --- a/dumpling/export/config.go +++ b/dumpling/export/config.go @@ -83,6 +83,9 @@ const ( flagClusterSSLCA = "cluster-ssl-ca" flagClusterSSLCert = "cluster-ssl-cert" flagClusterSSLKey = "cluster-ssl-key" + flagParquetCompress = "parquet-compress" + flagParquetPageSize = "parquet-page-size" + flagParquetRowGroupSize = "parquet-row-group-size" // FlagHelp represents the help flag 
// ParquetCompressType names the compression codec applied to parquet column
// chunks (internal to the parquet file). This is independent of --compress,
// which adjustFileFormat rejects for parquet output. The constant's string
// value is also embedded into the output file name as an extension infix
// (e.g. "…0000.zst.parquet") when compression is enabled.
type ParquetCompressType string

const (
	// NoCompression won't compress given bytes.
	NoCompression ParquetCompressType = "no-compression"
	// Gzip will compress given bytes in gzip format.
	// NOTE(review): the --parquet-compress help text advertises 'gzip' while
	// the value here is "gz" — confirm ParseParquetCompressType accepts the
	// advertised spelling as well.
	Gzip ParquetCompressType = "gz"
	// Snappy will compress given bytes in snappy format.
	Snappy ParquetCompressType = "snappy"
	// Zstd will compress given bytes in zstd format.
	// NOTE(review): help text advertises 'zstd' but the value here is "zst";
	// verify the parser maps both spellings.
	Zstd ParquetCompressType = "zst"
)
default unlimited") flags.String(flagWhere, "", "Dump only selected records") flags.Bool(flagEscapeBackslash, true, "use backslash to escape special characters") - flags.String(flagFiletype, "", "The type of export file (sql/csv)") + flags.String(flagFiletype, "", "The type of export file (sql/csv/parquet)") flags.Bool(flagNoHeader, false, "whether not to dump CSV table header") flags.BoolP(flagNoSchemas, "m", false, "Do not dump table schemas with the data") flags.BoolP(flagNoData, "d", false, "Do not dump table data") @@ -384,6 +403,9 @@ func (*Config) DefineFlags(flags *pflag.FlagSet) { flags.String(flagClusterSSLCA, "", "CA certificate path for TLS connections to PD endpoints used by GC control; if empty, reuse --ca") flags.String(flagClusterSSLCert, "", "Client certificate path for TLS connections to PD endpoints used by GC control; if empty, reuse --cert") flags.String(flagClusterSSLKey, "", "Client private key path for TLS connections to PD endpoints used by GC control; if empty, reuse --key") + flags.String(flagParquetCompress, "snappy", "Compress algorithm for parquet file, support 'no-compression', 'snappy', 'gzip', 'zstd'") + flags.Int64(flagParquetPageSize, 1024*1024, "Parquet page size in bytes") + flags.Int64(flagParquetRowGroupSize, 16*1024*1024, "Parquet row group size in bytes") } // ParseFromFlags parses dumpling's export.Config from flags @@ -625,6 +647,20 @@ func (conf *Config) ParseFromFlags(flags *pflag.FlagSet) error { return errors.Errorf("%s is only supported when dumping whole table to csv, not compatible with %s", flagCsvOutputDialect, conf.FileType) } conf.CsvOutputDialect, err = ParseOutputDialect(dialect) + + parquetCompressType, err := flags.GetString(flagParquetCompress) + if err != nil { + return errors.Trace(err) + } + conf.ParquetCompressType, err = ParseParquetCompressType(parquetCompressType) + if err != nil { + return errors.Trace(err) + } + conf.ParquetPageSize, err = flags.GetInt64(flagParquetPageSize) + if err != nil { + 
// ColumnInfo carries the per-column schema details the parquet writer needs
// to build a parquet schema: the column name, its database type name (e.g.
// "DECIMAL", "VARCHAR"), nullability, and — populated from
// sql.ColumnType.DecimalSize, so meaningful only for decimal types —
// precision and scale (both zero when DecimalSize reports not-ok).
type ColumnInfo struct {
	Name      string
	Type      string
	Nullable  bool
	Precision int64
	Scale     int64
}
@@ -57,6 +66,7 @@ type RowReceiverStringer interface { type Stringer interface { WriteToBuffer(*bytes.Buffer, bool) WriteToBufferInCsv(*bytes.Buffer, bool, *csvOption) + GetRawBytes() []sql.RawBytes } // RowReceiver is an interface which represents sql types that support bind address for *sql.Rows diff --git a/dumpling/export/ir_impl.go b/dumpling/export/ir_impl.go index ee75b62185e79..6a13955dd87d6 100644 --- a/dumpling/export/ir_impl.go +++ b/dumpling/export/ir_impl.go @@ -267,6 +267,25 @@ type tableMeta struct { hasImplicitRowID bool } +func (tm *tableMeta) ColumnInfos() []*ColumnInfo { + columnInfos := make([]*ColumnInfo, 0, len(tm.colTypes)) + for _, ct := range tm.colTypes { + nullable, _ := ct.Nullable() + precision, scale, ok := ct.DecimalSize() + if !ok { + precision, scale = 0, 0 + } + columnInfos = append(columnInfos, &ColumnInfo{ + Name: ct.Name(), + Type: ct.DatabaseTypeName(), + Nullable: nullable, + Precision: precision, + Scale: scale, + }) + } + return columnInfos +} + func (tm *tableMeta) ColumnTypes() []string { colTypes := make([]string, len(tm.colTypes)) for i, ct := range tm.colTypes { diff --git a/dumpling/export/sql_type.go b/dumpling/export/sql_type.go index e4b5a6a3458c2..3175ff834646b 100644 --- a/dumpling/export/sql_type.go +++ b/dumpling/export/sql_type.go @@ -230,6 +230,15 @@ func (r *RowReceiverArr) WriteToBufferInCsv(bf *bytes.Buffer, escapeBackslash bo } } +func (r RowReceiverArr) GetRawBytes() []sql.RawBytes { + rawBytes := make([]sql.RawBytes, len(r.receivers)) + for i, receiver := range r.receivers { + receiver.GetRawBytes() + rawBytes[i] = receiver.GetRawBytes()[0] + } + return rawBytes +} + // SQLTypeNumber implements RowReceiverStringer which represents numeric type columns in database type SQLTypeNumber struct { SQLTypeString @@ -253,6 +262,10 @@ func (s SQLTypeNumber) WriteToBufferInCsv(bf *bytes.Buffer, _ bool, opt *csvOpti } } +func (s *SQLTypeNumber) GetRawBytes() []sql.RawBytes { + return []sql.RawBytes{s.RawBytes} +} + 
// SQLTypeString implements RowReceiverStringer which represents string type columns in database type SQLTypeString struct { sql.RawBytes @@ -285,6 +298,10 @@ func (s *SQLTypeString) WriteToBufferInCsv(bf *bytes.Buffer, escapeBackslash boo } } +func (s *SQLTypeString) GetRawBytes() []sql.RawBytes { + return []sql.RawBytes{s.RawBytes} +} + // SQLTypeBytes implements RowReceiverStringer which represents bytes type columns in database type SQLTypeBytes struct { sql.RawBytes @@ -321,3 +338,7 @@ func (s *SQLTypeBytes) WriteToBufferInCsv(bf *bytes.Buffer, escapeBackslash bool bf.WriteString(opt.nullValue) } } + +func (s *SQLTypeBytes) GetRawBytes() []sql.RawBytes { + return []sql.RawBytes{s.RawBytes} +} diff --git a/dumpling/export/util_for_test.go b/dumpling/export/util_for_test.go index c0fa5d6fac308..2717e3dae2b88 100644 --- a/dumpling/export/util_for_test.go +++ b/dumpling/export/util_for_test.go @@ -160,6 +160,7 @@ type mockTableIR struct { hasImplicitRowID bool rowErr error rows *sql.Rows + columnInfos []*ColumnInfo SQLRowIter } @@ -256,6 +257,10 @@ func (m *mockTableIR) EscapeBackSlash() bool { return m.escapeBackSlash } +func (m *mockTableIR) ColumnInfos() []*ColumnInfo { + return m.columnInfos +} + func newMockTableIR(databaseName, tableName string, data [][]driver.Value, specialComments, colTypes []string) *mockTableIR { return &mockTableIR{ dbName: databaseName, @@ -268,3 +273,21 @@ func newMockTableIR(databaseName, tableName string, data [][]driver.Value, speci SQLRowIter: nil, } } + +func newMockTableIRWithColumnInfo(databaseName, tableName string, data [][]driver.Value, specialComments []string, infos []*ColumnInfo) *mockTableIR { + colTypes := make([]string, len(infos)) + for i, info := range infos { + colTypes[i] = info.Type + } + return &mockTableIR{ + dbName: databaseName, + tblName: tableName, + data: data, + specCmt: specialComments, + selectedField: "*", + selectedLen: len(infos), + colTypes: colTypes, + SQLRowIter: nil, + columnInfos: infos, + } +} 
diff --git a/dumpling/export/writer.go b/dumpling/export/writer.go index e10a4fd850dd6..0d8dc0f1328d4 100644 --- a/dumpling/export/writer.go +++ b/dumpling/export/writer.go @@ -58,6 +58,8 @@ func NewWriter( sw.fileFmt = FileFormatSQLText case FileFormatCSVString: sw.fileFmt = FileFormatCSV + case FileFormatParquetString: + sw.fileFmt = FileFormatParquet } return sw } @@ -232,7 +234,11 @@ func (w *Writer) WriteTableData(meta TableMeta, ir TableDataIR, currentChunk int func (w *Writer) tryToWriteTableData(tctx *tcontext.Context, meta TableMeta, ir TableDataIR, curChkIdx int) error { conf, format := w.conf, w.fileFmt namer := newOutputFileNamer(meta, curChkIdx, conf.Rows != UnspecifiedSize, conf.FileSize != UnspecifiedSize) - fileName, err := namer.NextName(conf.OutputFileTemplate, w.fileFmt.Extension()) + fileFmtExtension := format.Extension() + if format == FileFormatParquet && conf.ParquetCompressType != NoCompression { + fileFmtExtension = fmt.Sprintf("%s.%s", conf.ParquetCompressType, fileFmtExtension) + } + fileName, err := namer.NextName(conf.OutputFileTemplate, fileFmtExtension) if err != nil { return err } diff --git a/dumpling/export/writer_parquet.go b/dumpling/export/writer_parquet.go new file mode 100644 index 0000000000000..60cb34e2c067a --- /dev/null +++ b/dumpling/export/writer_parquet.go @@ -0,0 +1,378 @@ +// Copyright 2024 PingCAP, Inc. Licensed under Apache-2.0. 
// WriteInsertInParquet dumps one table chunk to w in parquet format. Rows are
// appended to an in-memory parquet writer backed by a pooled bytes.Buffer;
// whenever the buffer reaches the configured row-group size it is handed to a
// writerPipe goroutine that streams it to external storage. Returns the
// number of rows consumed from tblIR and the first error encountered.
func WriteInsertInParquet(
	pCtx *tcontext.Context,
	cfg *Config,
	meta TableMeta,
	tblIR TableDataIR,
	w storage.ExternalFileWriter,
	metrics *metrics,
) (n uint64, err error) {
	fileRowIter := tblIR.Rows()
	if !fileRowIter.HasNext() {
		return 0, fileRowIter.Error()
	}

	parquetLengthLimit := int(cfg.ParquetRowGroupSize)

	// Take a scratch buffer from the shared pool and make sure it can hold a
	// whole row group without re-growing mid-write.
	bf := pool.Get().(*bytes.Buffer)
	if bfCap := bf.Cap(); bfCap < parquetLengthLimit {
		bf.Grow(parquetLengthLimit - bfCap)
	}

	// parquet need to get more information from tableMeta
	writer, err := NewParquetWriter(bf, meta.ColumnInfos(), cfg)
	if err != nil {
		return 0, errors.Trace(err)
	}
	wp := newWriterPipe(w, cfg.FileSize, UnspecifiedSize, metrics, cfg.Labels)
	// use context.Background here to make sure writerPipe can deplete all the chunks in pipeline
	ctx, cancel := tcontext.Background().WithLogger(pCtx.L()).WithCancel()
	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
		wp.Run(ctx)
		wg.Done()
	}()
	// On any return: cancel the pipe's context, then wait for the pipe
	// goroutine to drain so no buffer is written after we leave.
	defer func() {
		cancel()
		wg.Wait()
	}()

	var (
		row            = MakeRowReceiver(meta.ColumnTypes())
		counter        uint64 // rows consumed from fileRowIter
		lastCounter    uint64 // rows already accounted in finishedRowsGauge
		selectedFields = meta.SelectedField()
	)

	// Reconcile metrics/summary once the outcome (err) is known: roll back
	// gauges on failure so a retry does not double-count, report totals on
	// success.
	defer func() {
		if err != nil {
			pCtx.L().Warn("fail to dumping table(chunk), will revert some metrics and start a retry if possible",
				zap.String("database", meta.DatabaseName()),
				zap.String("table", meta.TableName()),
				zap.Uint64("finished rows", lastCounter),
				zap.Uint64("finished size", wp.finishedFileSize),
				log.ShortError(err))
			SubGauge(metrics.finishedRowsGauge, float64(lastCounter))
			SubGauge(metrics.finishedSizeGauge, float64(wp.finishedFileSize))
		} else {
			pCtx.L().Debug("finish dumping table(chunk)",
				zap.String("database", meta.DatabaseName()),
				zap.String("table", meta.TableName()),
				zap.Uint64("finished rows", counter),
				zap.Uint64("finished size", wp.finishedFileSize))
			summary.CollectSuccessUnit(summary.TotalBytes, 1, wp.finishedFileSize)
			summary.CollectSuccessUnit("total rows", 1, counter)
		}
	}()

	// add magic number length: NewParquetWriter has already written the PAR1
	// magic into bf, so the buffer is non-empty before the first row.
	wp.currentFileSize += uint64(bf.Len())
	// add row to parquet writer, it will flush to buffer when reach row group size
	for fileRowIter.HasNext() {
		lastBfSize := bf.Len()
		if selectedFields != "" {
			if err = fileRowIter.Decode(row); err != nil {
				return counter, errors.Trace(err)
			}
			err = writer.WriteRow(*row)
			if err != nil {
				return counter, errors.Trace(err)
			}
		}
		counter++
		// buffer size only increase when parquet writer flush occurs
		wp.currentFileSize += uint64(bf.Len() - lastBfSize)
		if bf.Len() >= parquetLengthLimit {
			// Buffer holds at least one full row group: hand it to the pipe
			// and start a fresh one, unless the context died or the pipe
			// already failed.
			select {
			case <-pCtx.Done():
				return counter, pCtx.Err()
			case err = <-wp.errCh:
				return counter, err
			case wp.input <- bf:
				bf = pool.Get().(*bytes.Buffer)
				if bfCap := bf.Cap(); bfCap < parquetLengthLimit {
					bf.Grow(parquetLengthLimit - bfCap)
				}
				// Repoint the parquet writer's sink at the new buffer; the
				// old one is now owned by the writer pipe.
				writer.PFile = buffer.BufferFile{
					Writer: bf,
				}
				AddGauge(metrics.finishedRowsGauge, float64(counter-lastCounter))
				lastCounter = counter
			}
		}
		fileRowIter.Next()
		if wp.ShouldSwitchFile() {
			break
		}
	}

	// write remain data and meta file
	if err = writer.WriteStop(); err != nil {
		return counter, errors.Trace(err)
	}
	if bf.Len() > 0 {
		wp.input <- bf
	}
	close(wp.input)
	<-wp.closed
	AddGauge(metrics.finishedRowsGauge, float64(counter-lastCounter))
	// Keep lastCounter in sync so the deferred metrics rollback stays
	// consistent if a later error path is added.
	lastCounter = counter
	if err = fileRowIter.Error(); err != nil {
		return counter, errors.Trace(err)
	}
	return counter, wp.Error()
}
errors.Trace(err) + } + if bf.Len() > 0 { + wp.input <- bf + } + close(wp.input) + <-wp.closed + AddGauge(metrics.finishedRowsGauge, float64(counter-lastCounter)) + lastCounter = counter + if err = fileRowIter.Error(); err != nil { + return counter, errors.Trace(err) + } + return counter, wp.Error() +} + +func getParquetCompress(compress ParquetCompressType) parquet.CompressionCodec { + switch compress { + case NoCompression: + return parquet.CompressionCodec_UNCOMPRESSED + case Gzip: + return parquet.CompressionCodec_GZIP + case Snappy: + return parquet.CompressionCodec_SNAPPY + case Zstd: + return parquet.CompressionCodec_ZSTD + default: + return DefaultCompressionType + } +} + +type SQLWriter struct { + writer.ParquetWriter +} + +func NewParquetWriter(bf *bytes.Buffer, columns []*ColumnInfo, cfg *Config) (*SQLWriter, error) { + compress := getParquetCompress(cfg.ParquetCompressType) + pageSize := cfg.ParquetPageSize + rowGroupSize := cfg.ParquetRowGroupSize + + res := new(SQLWriter) + res.PFile = buffer.BufferFile{ + Writer: bf, + } + schemaHandler, err := NewSchemaHandlerFromSQL(columns) + if err != nil { + return nil, err + } + res.SchemaHandler = schemaHandler + res.NP = parquetParallelNumber + res.PageSize = pageSize + res.RowGroupSize = rowGroupSize + res.CompressionType = compress + res.PagesMapBuf = make(map[string][]*layout.Page) + res.DictRecs = make(map[string]*layout.DictRecType) + res.ColumnIndexes = make([]*parquet.ColumnIndex, 0) + res.OffsetIndexes = make([]*parquet.OffsetIndex, 0) + res.MarshalFunc = marshal.MarshalCSV + // footer + res.Footer = parquet.NewFileMetaData() + res.Footer.Version = 1 + res.Footer.Schema = append(res.Footer.Schema, res.SchemaHandler.SchemaElements...) 
+ // add magic number + magicNumber := []byte(parquetMagicNumber) + bf.Write([]byte(parquetMagicNumber)) + res.Offset = int64(len(magicNumber)) + return res, err +} + +// WriteRow writes a row to the parquet format +func (w *SQLWriter) WriteRow(r RowReceiverArr) error { + var err error + sqlRaws := r.GetRawBytes() + rec := make([]any, len(sqlRaws)) + for i := 0; i < len(sqlRaws); i++ { + rec[i] = nil + if sqlRaws[i] != nil { + rec[i], err = convertDataToParquet(sqlRaws[i], + w.SchemaHandler.SchemaElements[i+1].Type, + w.SchemaHandler.SchemaElements[i+1].LogicalType) + if err != nil { + return err + } + } + } + return w.Write(rec) +} + +// convertDataToParquet some codes references parquet-go types.StrToParquetType() +func convertDataToParquet(data []byte, pT *parquet.Type, lT *parquet.LogicalType) (any, error) { + s := string(data) + // convert with logical type + if lT != nil && lT.DECIMAL != nil { + numSca := big.NewFloat(1.0) + for i := 0; i < int(lT.DECIMAL.Scale); i++ { + numSca.Mul(numSca, big.NewFloat(10)) + } + num := new(big.Float) + num.SetString(s) + num.Mul(num, numSca) + + if *pT == parquet.Type_INT32 { + tmp, _ := num.Int64() + return int32(tmp), nil + } else if *pT == parquet.Type_INT64 { + tmp, _ := num.Int64() + return tmp, nil + } else if *pT == parquet.Type_FIXED_LEN_BYTE_ARRAY { + s = num.Text('f', 0) + // only used by unsigned big int + res := types.StrIntToBinary(s, "BigEndian", 9, true) + return res, nil + } + } + // TiDB may return invalid timestamp/datetime such as "0000-00-00 00:00:00" or "0001-00-00 00:00:00" + // In this case, we will return nil directly + if lT != nil && lT.TIMESTAMP != nil { + s, err := time.Parse(time.DateTime, s) + if err != nil { + return nil, nil + } + return s.UnixMicro(), nil + } + // convert with primitive type + if *pT == parquet.Type_BOOLEAN { + var v bool + _, err := fmt.Sscanf(s, "%t", &v) + return v, err + } else if *pT == parquet.Type_INT32 { + var v int32 + _, err := fmt.Sscanf(s, "%d", &v) + return v, 
err + } else if *pT == parquet.Type_INT64 { + var v int64 + _, err := fmt.Sscanf(s, "%d", &v) + return v, err + } else if *pT == parquet.Type_FLOAT { + var v float32 + _, err := fmt.Sscanf(s, "%f", &v) + return v, err + } else if *pT == parquet.Type_DOUBLE { + var v float64 + _, err := fmt.Sscanf(s, "%f", &v) + return v, err + } else if *pT == parquet.Type_BYTE_ARRAY { + return s, nil + + } else if *pT == parquet.Type_FIXED_LEN_BYTE_ARRAY { + return s, nil + } + return nil, fmt.Errorf("unsupported type %v", pT) + +} + +// NewSchemaHandlerFromSQL build tag and reuse the NewSchemaHandlerFromMetadata to create the schema handler +func NewSchemaHandlerFromSQL(columns []*ColumnInfo) (*schema.SchemaHandler, error) { + tags := make([]string, 0, len(columns)) + for _, column := range columns { + primitive, logical := ToParquetType(column) + repetitionType := parquet.FieldRepetitionType_REQUIRED + // set TIMESTAMP and DATETIME to optional because we will set null for those invalid values + // such as: 0000-00-00 00:00:00 or 0001-00-00 00:00:00 + if column.Nullable || column.Type == "TIMESTAMP" || column.Type == "DATETIME" { + repetitionType = parquet.FieldRepetitionType_OPTIONAL + } + tag := fmt.Sprintf(TagTemplate, + column.Name, + primitive.String(), + repetitionType.String()) + if logical != "" { + tag = fmt.Sprintf("%s, %s", tag, logical) + } + tags = append(tags, tag) + } + // use tag to define the schema + schemaHandler, err := schema.NewSchemaHandlerFromMetadata(tags) + if err != nil { + return nil, err + } + return schemaHandler, nil +} + +// ToParquetType converts database type to parquet type +func ToParquetType(columnInfo *ColumnInfo) (parquet.Type, string) { + // TiDB Type + columnType := columnInfo.Type + switch columnType { + case "CHAR", "VARCHAR", "DATE", "TIME", "TEXT", "TINYTEXT", "MEDIUMTEXT", "LONGTEXT", "SET", "JSON", "VECTOR": + return parquet.Type_BYTE_ARRAY, "logicaltype=STRING" + case "ENUM": + return parquet.Type_BYTE_ARRAY, "logicaltype=STRING" + 
case "BLOB", "TINYBLOB", "MEDIUMBLOB", "LONGBLOB", "BINARY", "VARBINARY", "BIT": + return parquet.Type_BYTE_ARRAY, "" + case "TIMESTAMP", "DATETIME": + return parquet.Type_INT64, "logicaltype=TIMESTAMP, logicaltype.isadjustedtoutc=false, logicaltype.unit=MICROS" + case "YEAR", "TINYINT", "SMALLINT", "MEDIUMINT", "UNSIGNED TINYINT", "UNSIGNED SMALLINT", "UNSIGNED MEDIUMINT", "INT": + return parquet.Type_INT32, "" + case "BIGINT", "UNSIGNED INT": + return parquet.Type_INT64, "" + case "DECIMAL": + // add converted type for backward compatibility + logicalType := fmt.Sprintf("logicaltype=DECIMAL, logicaltype.precision=%d, logicaltype.scale=%d, convertedtype=DECIMAL, precision=%d, scale=%d", + columnInfo.Precision, columnInfo.Scale, columnInfo.Precision, columnInfo.Scale) + if columnInfo.Precision <= 9 { + return parquet.Type_INT32, logicalType + } + // int64 has 19 digits, so it can store 18 digits decimal + if columnInfo.Precision <= 18 { + return parquet.Type_INT64, logicalType + } + return parquet.Type_BYTE_ARRAY, "logicaltype=STRING" + case "UNSIGNED BIGINT": + // add converted type for backward compatibility + return parquet.Type_FIXED_LEN_BYTE_ARRAY, "length=9, logicaltype=DECIMAL, logicaltype.precision=20, logicaltype.scale=0, convertedtype=DECIMAL, precision=20, scale=0" + case "FLOAT": + return parquet.Type_FLOAT, "" + case "DOUBLE": + return parquet.Type_DOUBLE, "" + } + + // Other database, like MariaDB + switch columnType { + case "NCHAR", "NVARCHAR", "CHARACTER", "VARCHARACTER", "SQL_TSI_YEAR", "NULL", "VAR_STRING", "GEOMETRY", "LONG": + return parquet.Type_BYTE_ARRAY, "" + case "INTEGER", "INT1", "INT2", "INT3": + return parquet.Type_INT32, "" + case "INT8": + return parquet.Type_INT64, "" + case "BOOL", "BOOLEAN": + return parquet.Type_BOOLEAN, "" + case "REAL", "DOUBLE", "DOUBLE PRECISION", "NUMERIC", "FIXED": + return parquet.Type_DOUBLE, "" + } + return parquet.Type_BYTE_ARRAY, "" +} diff --git a/dumpling/export/writer_parquet_test.go 
// TestWriteTableDataWithParquet round-trips one row containing every
// supported column type: it dumps the mocked row to a parquet file, reads
// the file back with parquet-go's reader, and checks each column value
// (located in ParquetRow by reflection on a name derived from the column
// name) against the original mock input.
func TestWriteTableDataWithParquet(t *testing.T) {
	dir := t.TempDir()
	config := defaultConfigForTest(t)
	config.FileType = "parquet"
	config.OutputDirPath = dir
	config.ParquetPageSize = 1024 * 1024
	config.ParquetCompressType = NoCompression
	// mock data: one representative value per column type in allTypesColumnInfos.
	row := make([]driver.Value, 0, len(allTypesColumnInfos))
	for _, c := range allTypesColumnInfos {
		switch c.Type {
		case "CHAR", "VARCHAR", "TEXT", "TINYTEXT", "MEDIUMTEXT", "LONGTEXT":
			row = append(row, "byte_array_string")
		case "DATE":
			row = append(row, "1977-01-01")
		case "TIME":
			row = append(row, "23:59:59")
		case "JSON":
			row = append(row, "{\"a\": 1, \"b\": \"2\"}")
		case "VECTOR":
			row = append(row, "[1,2,3]")
		case "ENUM":
			row = append(row, "byte_array_enum")
		case "SET":
			row = append(row, "a,b")
		case "BLOB", "TINYBLOB", "MEDIUMBLOB", "LONGBLOB", "BINARY", "VARBINARY":
			byteValue, _ := hex.DecodeString("1520c5")
			row = append(row, byteValue)
		case "BIT":
			byteValue := []byte{0x55}
			row = append(row, byteValue)
		case "TIMESTAMP":
			row = append(row, "1973-12-30 15:30:00")
		case "DATETIME":
			row = append(row, "9999-12-31 23:59:59")
		case "TINYINT", "SMALLINT", "MEDIUMINT", "UNSIGNED TINYINT", "UNSIGNED SMALLINT", "UNSIGNED MEDIUMINT", "YEAR":
			row = append(row, "1995")
		case "INT":
			row = append(row, "-2147483648")
		case "UNSIGNED INT":
			row = append(row, "4294967295")
		case "BIGINT":
			row = append(row, "-9223372036854775808")
		case "UNSIGNED BIGINT":
			row = append(row, "18446744073709551615")
		case "FLOAT", "DOUBLE":
			row = append(row, "123.123")
		case "DECIMAL":
			// One mock value per precision used in allTypesColumnInfos
			// (9/18 fit integers, 38/40 fall back to strings).
			if c.Precision == 9 {
				row = append(row, "12345678.9")
			} else if c.Precision == 18 {
				row = append(row, "12345678912345678.9")
			} else if c.Precision == 38 {
				row = append(row, "1234567890123456789012345678901234567.8")
			} else if c.Precision == 40 {
				row = append(row, "123456789012345678901234567890123456780.9")
			} else {
				t.FailNow()
			}
		default:
			t.FailNow()
		}
	}
	data := [][]driver.Value{row}
	// parquet need more information to write, so we need to provide columnInfos
	tableIR := newMockTableIRWithColumnInfo("test", "employee", data, nil, allTypesColumnInfos)
	// write parquet
	writer := createTestWriter(config, t)
	err := writer.WriteTableData(tableIR, tableIR, 0)
	require.NoError(t, err)

	// read parquet and check answer
	fr, err := local.NewLocalFileReader(path.Join(config.OutputDirPath, "test.employee.000000000.parquet"))
	require.NoError(t, err)
	pr, err := reader.NewParquetReader(fr, new(ParquetRow), 1)
	require.NoError(t, err)
	readRows := make([]ParquetRow, 1)
	err = pr.Read(&readRows)
	require.NoError(t, err)
	readRow := readRows[0]

	for i, c := range allTypesColumnInfos {
		// Derive the ParquetRow field name from the column name: drop the
		// leading "t" segment and CamelCase the rest, e.g.
		// "t_tinyint_unsigned" -> "TinyintUnsigned".
		names := strings.Split(c.Name, "_")
		structName := ""
		for j := 1; j < len(names); j++ {
			structName += fmt.Sprintf("%s%s", strings.ToUpper(names[j][:1]), names[j][1:])
		}
		switch c.Type {
		case "CHAR", "VARCHAR", "TEXT", "TINYTEXT", "MEDIUMTEXT", "LONGTEXT", "DATE", "TIME", "JSON", "VECTOR", "ENUM", "SET":
			value := reflect.ValueOf(readRow).FieldByName(structName).String()
			require.Equal(t, data[0][i], value)
		case "BLOB", "TINYBLOB", "MEDIUMBLOB", "LONGBLOB", "BINARY", "VARBINARY", "BIT":
			value := reflect.ValueOf(readRow).FieldByName(structName).String()
			require.Equal(t, data[0][i], []byte(value))
		case "TIMESTAMP", "DATETIME":
			// Stored as micros-since-epoch (*int64); format back to the
			// original textual form for comparison.
			value := reflect.ValueOf(readRow).FieldByName(structName).Interface().(*int64)
			require.Equal(t, data[0][i], time.UnixMicro(*value).UTC().Format(time.DateTime))
		case "TINYINT", "UNSIGNED TINYINT", "SMALLINT", "MEDIUMINT", "UNSIGNED SMALLINT", "UNSIGNED MEDIUMINT", "YEAR", "INT":
			value := reflect.ValueOf(readRow).FieldByName(structName).Int()
			require.Equal(t, data[0][i], strconv.FormatInt(value, 10))
		case "UNSIGNED INT", "BIGINT":
			value := reflect.ValueOf(readRow).FieldByName(structName).Int()
			require.Equal(t, data[0][i], strconv.FormatInt(value, 10))
		case "UNSIGNED BIGINT":
			// Stored as a 9-byte big-endian DECIMAL(20,0); decode via
			// parquet-go's helper.
			value := reflect.ValueOf(readRow).FieldByName(structName).String()
			bigint := types.DECIMAL_BYTE_ARRAY_ToString([]byte(value), 20, 0)
			require.Equal(t, data[0][i], bigint)
		case "FLOAT", "DOUBLE":
			// NOTE(review): ParquetRow declares Float as float64 while its
			// tag says type=FLOAT (4-byte) — confirm parquet-go widens this
			// on read; the %.3f comparison masks any precision drift.
			value := reflect.ValueOf(readRow).FieldByName(structName).Float()
			require.Equal(t, data[0][i], fmt.Sprintf("%.3f", value))
		case "DECIMAL":
			if c.Precision == 9 || c.Precision == 18 {
				// Unscaled integer storage: re-insert the decimal point at
				// Scale digits from the right before comparing.
				value := reflect.ValueOf(readRow).FieldByName(structName).Int()
				stringInt := strconv.FormatInt(value, 10)
				result := stringInt[:len(stringInt)-int(c.Scale)] + "." + stringInt[len(stringInt)-int(c.Scale):]
				require.Equal(t, data[0][i], result)
			} else if c.Precision == 38 || c.Precision == 40 {
				// Wide decimals are dumped as strings verbatim.
				value := reflect.ValueOf(readRow).FieldByName(structName).String()
				require.Equal(t, data[0][i], value)
			}
		}
	}
}
type=BYTE_ARRAY"` + Year int32 `parquet:"name=t_year, type=INT32"` + Enum string `parquet:"name=t_enum, type=BYTE_ARRAY"` + Set string `parquet:"name=t_set, type=BYTE_ARRAY"` + Bit string `parquet:"name=t_bit, type=BYTE_ARRAY"` + Json string `parquet:"name=t_json, type=BYTE_ARRAY"` + Decimal9 int32 `parquet:"name=t_decimal9, type=INT32"` + Decimal18 int64 `parquet:"name=t_decimal18, type=INT64"` + Decimal38 string `parquet:"name=t_decimal38, type=BYTE_ARRAY"` + Decimal40 string `parquet:"name=t_decimal40, type=BYTE_ARRAY"` + Vector string `parquet:"name=t_vector, type=BYTE_ARRAY"` +} + +var allTypesColumnInfos = []*ColumnInfo{ + { + Name: "t_tinyint", + Type: "TINYINT", + }, + { + Name: "t_tinyint_unsigned", + Type: "UNSIGNED TINYINT", + }, + { + Name: "t_smallint", + Type: "SMALLINT", + }, + { + Name: "t_smallint_unsigned", + Type: "UNSIGNED SMALLINT", + }, + { + Name: "t_mediumint", + Type: "MEDIUMINT", + }, + { + Name: "t_mediumint_unsigned", + Type: "UNSIGNED MEDIUMINT", + }, + { + Name: "t_int", + Type: "INT", + }, + { + Name: "t_int_unsigned", + Type: "UNSIGNED INT", + }, + { + Name: "t_bigint", + Type: "BIGINT", + }, + { + Name: "t_bigint_unsigned", + Type: "UNSIGNED BIGINT", + }, + { + Name: "t_float", + Type: "FLOAT", + }, + { + Name: "t_double", + Type: "DOUBLE", + }, + { + Name: "t_char", + Type: "CHAR", + }, + { + Name: "t_varchar", + Type: "VARCHAR", + }, + { + Name: "t_binary", + Type: "BINARY", + }, + { + Name: "t_varbinary", + Type: "VARBINARY", + }, + { + Name: "t_tinytext", + Type: "TINYTEXT", + }, + { + Name: "t_text", + Type: "TEXT", + }, + { + Name: "t_mediumtext", + Type: "MEDIUMTEXT", + }, + { + Name: "t_longtext", + Type: "LONGTEXT", + }, + { + Name: "t_tinyblob", + Type: "TINYBLOB", + }, + { + Name: "t_blob", + Type: "BLOB", + }, + { + Name: "t_mediumblob", + Type: "MEDIUMBLOB", + }, + { + Name: "t_longblob", + Type: "LONGBLOB", + }, + { + Name: "t_date", + Type: "DATE", + }, + { + Name: "t_datetime", + Type: "DATETIME", + }, + { + Name: 
// TestWriteParquetWithInvalidDate checks that invalid and zero
// DATETIME/TIMESTAMP values (e.g. "0000-00-00 00:00:00"), which TiDB can
// legitimately return, are written as parquet NULLs rather than failing the
// dump — this relies on those columns being declared OPTIONAL in the schema.
func TestWriteParquetWithInvalidDate(t *testing.T) {
	dir := t.TempDir()
	config := defaultConfigForTest(t)
	config.FileType = "parquet"
	config.OutputDirPath = dir
	config.ParquetPageSize = 1024 * 1024
	config.ParquetCompressType = NoCompression

	columnInfos := []*ColumnInfo{
		{
			Name: "t_datetime",
			Type: "DATETIME",
		},
		{
			Name: "t_timestamp",
			Type: "TIMESTAMP",
		},
	}
	// mock data
	mockDatas := [][]string{
		{"1971-00-00 00:00:00", "0000-00-00 00:00:00"}, // invalid date
		{"0000-00-00 00:00:00", "0000-00-00 00:00:00"}, // zero date
	}
	data := make([][]driver.Value, len(mockDatas))
	for i, mockData := range mockDatas {
		row := make([]driver.Value, 0, len(columnInfos))
		for _, mock := range mockData {
			row = append(row, mock)
		}
		data[i] = row
	}
	// parquet need more information to write, so we need to provide columnInfos
	tableIR := newMockTableIRWithColumnInfo("test", "date", data, nil, columnInfos)
	// write parquet
	writer := createTestWriter(config, t)
	err := writer.WriteTableData(tableIR, tableIR, 0)
	require.NoError(t, err)

	// read parquet and check answer; pointer fields so NULL reads back as nil
	type dateRow struct {
		Datetime  *int64 `parquet:"name=t_datetime, type=INT64"`
		Timestamp *int64 `parquet:"name=t_timestamp, type=INT64"`
	}

	fr, err := local.NewLocalFileReader(path.Join(config.OutputDirPath, "test.date.000000000.parquet"))
	require.NoError(t, err)
	pr, err := reader.NewParquetReader(fr, new(dateRow), 1)
	require.NoError(t, err)
	readRows := make([]dateRow, len(mockDatas))
	err = pr.Read(&readRows)
	require.NoError(t, err)

	// every mocked value is invalid, so every cell must come back NULL
	exceptedResult := [][]*int64{
		{nil, nil},
		{nil, nil},
	}

	for i, row := range readRows {
		for j, column := range columnInfos {
			if column.Type == "DATETIME" {
				require.Equal(t, exceptedResult[i][j], row.Datetime)
			} else if column.Type == "TIMESTAMP" {
				require.Equal(t, exceptedResult[i][j], row.Timestamp)
			}
		}
	}
}
@@ -615,6 +619,8 @@ func (f FileFormat) String() string { return strings.ToUpper(FileFormatSQLTextString) case FileFormatCSV: return strings.ToUpper(FileFormatCSVString) + case FileFormatParquet: + return strings.ToUpper(FileFormatParquetString) default: return "unknown" } @@ -630,6 +636,8 @@ func (f FileFormat) Extension() string { return FileFormatSQLTextString case FileFormatCSV: return FileFormatCSVString + case FileFormatParquet: + return FileFormatParquetString default: return "unknown_format" } @@ -649,6 +657,8 @@ func (f FileFormat) WriteInsert( return WriteInsert(pCtx, cfg, meta, tblIR, w, metrics) case FileFormatCSV: return WriteInsertInCsv(pCtx, cfg, meta, tblIR, w, metrics) + case FileFormatParquet: + return WriteInsertInParquet(pCtx, cfg, meta, tblIR, w, metrics) default: return 0, errors.Errorf("unknown file format") }