From 9237a8300c22a226a23b5548f95f210a2e23ddf3 Mon Sep 17 00:00:00 2001 From: Stefano Scafiti Date: Wed, 24 Jul 2024 15:09:26 +0200 Subject: [PATCH] Implement CSV import/export of tables Signed-off-by: Stefano Scafiti --- cmd/immuadmin/command/commandline.go | 1 + cmd/immuadmin/command/database.go | 236 +++++++++++++++++++++++++++ pkg/api/schema/sql.go | 4 +- pkg/api/schema/sql_test.go | 2 +- 4 files changed, 240 insertions(+), 3 deletions(-) diff --git a/cmd/immuadmin/command/commandline.go b/cmd/immuadmin/command/commandline.go index adab90fb63..0f86965a6f 100644 --- a/cmd/immuadmin/command/commandline.go +++ b/cmd/immuadmin/command/commandline.go @@ -96,6 +96,7 @@ func (cl *commandline) Register(rootCmd *cobra.Command) *cobra.Command { cl.stats(rootCmd) cl.serverConfig(rootCmd) cl.database(rootCmd) + return rootCmd } diff --git a/cmd/immuadmin/command/database.go b/cmd/immuadmin/command/database.go index 2e5ca1a272..53305dbcc9 100644 --- a/cmd/immuadmin/command/database.go +++ b/cmd/immuadmin/command/database.go @@ -17,16 +17,22 @@ limitations under the License. package immuadmin import ( + "encoding/csv" "fmt" + "io" + "os" + "path" "strconv" "strings" "time" "github.com/codenotary/immudb/cmd/helper" c "github.com/codenotary/immudb/cmd/helper" + "github.com/codenotary/immudb/embedded/sql" "github.com/codenotary/immudb/embedded/store" "github.com/codenotary/immudb/embedded/tbtree" "github.com/codenotary/immudb/pkg/api/schema" + "github.com/codenotary/immudb/pkg/client" "github.com/codenotary/immudb/pkg/database" "github.com/codenotary/immudb/pkg/replication" "github.com/spf13/cobra" @@ -386,10 +392,240 @@ func (cl *commandline) database(cmd *cobra.Command) { dbCmd.AddCommand(flushCmd) dbCmd.AddCommand(compactCmd) dbCmd.AddCommand(truncateCmd) + dbCmd.AddCommand(cl.createExportCmd()) + dbCmd.AddCommand(cl.createImportCmd()) cmd.AddCommand(dbCmd) } +func (cl *commandline) createExportCmd() *cobra.Command { + exportCmd := &cobra.Command{ + Use: "export", + Short: "Dump an SQL table to a CSV file", + Aliases: []string{"e"}, + ArgAliases: []string{"table"}, + PersistentPreRunE: cl.ConfigChain(cl.connect), + PersistentPostRun: cl.disconnect, + RunE: func(cmd *cobra.Command, args []string) error { + table := args[0] + + outputPath, _ := cmd.Flags().GetString("o") + if outputPath == "" { + wd, err := os.Getwd() + if err != nil { + return err + } + outputPath = path.Join(wd, table) + ".csv" + } + + reader, err := cl.immuClient.SQLQueryReader(cl.context, fmt.Sprintf("SELECT * FROM %s", table), nil) + if err != nil { + return err + } + defer reader.Close() + + csvFile, err := os.Create(outputPath) + if err != nil { + return err + } + defer csvFile.Close() + + sep, err := cmd.Flags().GetString("s") + if err != nil { + return err + } + if len(sep) != 1 { + return fmt.Errorf("invalid separator") + } + + writer := csv.NewWriter(csvFile) + writer.Comma = rune(sep[0]) + writer.UseCRLF = true + defer writer.Flush() + + cols := reader.Columns() + + colNames := make([]string, len(cols)) + for i, col := range cols { + colNames[i] = formatColName(col.Name) + } + + if err := writer.Write(colNames); err != nil { + return err + } + + out := make([]string, len(cols)) + for reader.Next() { + row, err := reader.Read() + if err != nil { + return err + } + + if err := rowToCSV(row, cols, out); err != nil { + return err + } + + if err := writer.Write(out); err != nil { + return err + } + } + return writer.Error() + }, + Args: cobra.ExactArgs(1), + } + exportCmd.Flags().String("o", "", "output") + exportCmd.Flags().String("s", ",", "separator") + + return exportCmd +} + +func rowToCSV(row client.Row, cols []client.Column, out []string) error { + for i, v := range row { + colType := cols[i].Type + rv, err := renderValue(v, colType) + if err != nil { + return err + } + out[i] = rv + } + return nil +} + +func renderValue(v interface{}, colType string) (string, error) { + switch colType { + case sql.VarcharType, sql.JSONType, sql.UUIDType: + s, isStr := v.(string) + if !isStr { + return "", fmt.Errorf("invalid value received") + } + return s, nil + default: + sqlVal, err := schema.AsSQLValue(v) + if err != nil { + return "", err + } + return schema.RenderValue(sqlVal.Value), nil + } +} + +func (cl *commandline) createImportCmd() *cobra.Command { + importCmd := &cobra.Command{ + Use: "import", + Short: "Insert data to an existing table from a csv file", + Aliases: []string{"i"}, + ArgAliases: []string{"file"}, + PersistentPreRunE: cl.ConfigChain(cl.connect), + PersistentPostRun: cl.disconnect, + RunE: func(cmd *cobra.Command, args []string) error { + inputPath := args[0] + + csvFile, err := os.Open(inputPath) + if err != nil { + return err + } + defer csvFile.Close() + + sep, err := cmd.Flags().GetString("s") + if err != nil { + return err + } + if len(sep) != 1 { + return fmt.Errorf("invalid separator") + } + + reader := csv.NewReader(csvFile) + reader.Comma = rune(sep[0]) + reader.ReuseRecord = true + + hasHeader, err := cmd.Flags().GetBool("h") + if err != nil { + return err + } + + table, err := cmd.Flags().GetString("t") + if err != nil { + return err + } + if table == "" { + return fmt.Errorf("table name not specified") + } + + if hasHeader { + _, err := reader.Read() + if err != nil && err != io.EOF { + return nil + } + } + + // fetch column information + res, err := cl.immuClient.SQLQuery(cl.context, fmt.Sprintf("SELECT * FROM %s WHERE 1 = 0", table), nil, false) + if err != nil { + return err + } + + cols := make([]string, len(res.Columns)) + for i, col := range res.Columns { + cols[i] = formatColName(col.Name) + } + + row, err := reader.Read() + for err == nil { + if len(row) != len(cols) { + return fmt.Errorf("wrong number of columns") + } + + for i, v := range row { + row[i] = formatInsertValue(v, res.Columns[i].Type) + } + + _, err = cl.immuClient.SQLExec( + cl.context, + fmt.Sprintf("INSERT INTO %s(%s) VALUES (%s)", table, strings.Join(cols, ","), strings.Join(row, ",")), + nil, + ) + if err != nil { + return err + } + row, err = reader.Read() + } + if err != io.EOF { + return err + } + return nil + }, + Args: cobra.ExactArgs(1), + } + importCmd.Flags().String("t", "", "table") + importCmd.Flags().Bool("h", true, "interpret the first column as header") + importCmd.Flags().String("s", ",", "separator") + + return importCmd +} + +func formatColName(col string) string { + idx := strings.Index(col, ".") + if idx >= 0 { + return col[idx+1 : len(col)-1] + } + return col +} + +func formatInsertValue(v string, colType string) string { + if v == "NULL" { + return v + } + + switch colType { + case sql.VarcharType: + return fmt.Sprintf("'%s'", v) + case sql.TimestampType, sql.JSONType, sql.UUIDType: + return fmt.Sprintf("CAST ('%s' AS %s)", v, colType) + case sql.BLOBType: + return fmt.Sprintf("x'%s'", v) + } + return v +} + func prepareDatabaseNullableSettings(flags *pflag.FlagSet) (*schema.DatabaseNullableSettings, error) { var err error diff --git a/pkg/api/schema/sql.go b/pkg/api/schema/sql.go index 74999b3189..f2c13b4d7f 100644 --- a/pkg/api/schema/sql.go +++ b/pkg/api/schema/sql.go @@ -32,7 +32,7 @@ func EncodeParams(params map[string]interface{}) ([]*NamedParam, error) { i := 0 for n, v := range params { - sqlVal, err := asSQLValue(v) + sqlVal, err := AsSQLValue(v) if err != nil { return nil, err } @@ -52,7 +52,7 @@ func NamedParamsFromProto(protoParams []*NamedParam) map[string]interface{} { return params } -func asSQLValue(v interface{}) (*SQLValue, error) { +func AsSQLValue(v interface{}) (*SQLValue, error) { if v == nil { return &SQLValue{Value: &SQLValue_Null{}}, nil } diff --git a/pkg/api/schema/sql_test.go b/pkg/api/schema/sql_test.go index 06aac50e33..b2b05f51b5 100644 --- a/pkg/api/schema/sql_test.go +++ b/pkg/api/schema/sql_test.go @@ -106,7 +106,7 @@ func TestAsSQLValue(t *testing.T) { }, } { t.Run(d.n, func(t *testing.T) { - sqlVal, err := asSQLValue(d.val) + sqlVal, err := AsSQLValue(d.val) require.EqualValues(t, d.sqlVal, sqlVal) if d.isErr { require.ErrorIs(t, err, sql.ErrInvalidValue)