diff --git a/.gitignore b/.gitignore index 8960131..8799798 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ +.idea dist/ grabit grabit.lock diff --git a/cmd/add.go b/cmd/add.go index 28c3444..51fa9e0 100644 --- a/cmd/add.go +++ b/cmd/add.go @@ -42,13 +42,15 @@ func runAdd(cmd *cobra.Command, args []string) error { if err != nil { return err } - err = lock.AddResource(args, algo, tags, filename) + dynamic, err := cmd.Flags().GetBool("dynamic") if err != nil { return err } - err = lock.Save() + + err = lock.AddResource(args, algo, tags, filename, dynamic) if err != nil { return err } - return nil + + return lock.Save() } diff --git a/cmd/download.go b/cmd/download.go index 2e29611..36b3541 100644 --- a/cmd/download.go +++ b/cmd/download.go @@ -5,6 +5,8 @@ package cmd import ( "github.com/cisco-open/grabit/internal" + "github.com/rs/zerolog" + "github.com/rs/zerolog/log" "github.com/spf13/cobra" ) @@ -19,10 +21,20 @@ func addDownload(cmd *cobra.Command) { downloadCmd.Flags().StringArray("tag", []string{}, "Only download the resources with the given tag") downloadCmd.Flags().StringArray("notag", []string{}, "Only download the resources without the given tag") downloadCmd.Flags().String("perm", "", "Optional permissions for the downloaded files (e.g. '644')") + downloadCmd.Flags().BoolP("verbose", "v", false, "Enable verbose output") cmd.AddCommand(downloadCmd) } func runFetch(cmd *cobra.Command, args []string) error { + logLevel, _ := cmd.Flags().GetString("log-level") + level, _ := zerolog.ParseLevel(logLevel) + zerolog.SetGlobalLevel(level) + + if level <= zerolog.DebugLevel { + log.Debug().Msg("Starting download") + // Add more debug logs as needed + } + lockFile, err := cmd.Flags().GetString("lock-file") if err != nil { return err @@ -47,9 +59,21 @@ func runFetch(cmd *cobra.Command, args []string) error { if err != nil { return err } - err = lock.Download(dir, tags, notags, perm) + + d := cmd.Context().Value("downloader").(*downloader.Downloader) + + if verbose { + log.Debug().Str("lockFile", lockFile).Str("dir", dir).Strs("tags", tags).Strs("notags", notags).Str("perm", perm).Msg("Starting download") + } + + err = lock.Download(dir, tags, notags, perm, d) if err != nil { return err } + + if verbose { + log.Debug().Msg("Download completed successfully") + } + return nil } diff --git a/cmd/dowload_test.go b/cmd/download_test.go similarity index 100% rename from cmd/dowload_test.go rename to cmd/download_test.go diff --git a/cmd/root.go b/cmd/root.go index 854dab2..55ddd35 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -4,6 +4,7 @@ package cmd import ( + "context" "os" "path/filepath" "strings" @@ -13,6 +14,8 @@ import ( "github.com/spf13/cobra" ) +var verbose bool + func NewRootCmd() *cobra.Command { cmd := &cobra.Command{ Use: "grabit", @@ -21,10 +24,14 @@ func NewRootCmd() *cobra.Command { } cmd.PersistentFlags().StringP("lock-file", "f", filepath.Join(getPwd(), GRAB_LOCK), "lockfile path (default: $PWD/grabit.lock") cmd.PersistentFlags().StringP("log-level", "l", "info", "log level (trace, debug, info, warn, error, fatal)") + cmd.PersistentFlags().BoolVarP(&verbose, "verbose", "v", false, "Enable verbose output") + addDelete(cmd) addDownload(cmd) addAdd(cmd) addVersion(cmd) + AddUpdate(cmd) + AddVerify(cmd) return cmd } @@ -59,6 +66,11 @@ func initLog(ll string) { } func Execute(rootCmd *cobra.Command) { + rootCmd.PersistentPreRun = func(cmd *cobra.Command, args []string) { + ctx := context.WithValue(cmd.Context(), "downloader", d) + cmd.SetContext(ctx) + } + ll, err := rootCmd.PersistentFlags().GetString("log-level") if err != nil { log.Fatal().Msg(err.Error()) diff --git a/cmd/root_test.go b/cmd/root_test.go index 25d86a6..9cb7723 100644 --- a/cmd/root_test.go +++ b/cmd/root_test.go @@ -3,6 +3,7 @@ package cmd import ( "bytes" "testing" + "time" "github.com/stretchr/testify/assert" ) @@ -11,6 +12,7 @@ func TestRunRoot(t *testing.T) { rootCmd := NewRootCmd() buf := new(bytes.Buffer) rootCmd.SetOutput(buf) + d := downloader.NewDownloader(10 * time.Second) // Create a new downloader with a 10-second timeout Execute(rootCmd) assert.Contains(t, buf.String(), "and verifies their integrity") } diff --git a/cmd/update.go b/cmd/update.go new file mode 100644 index 0000000..c75d525 --- /dev/null +++ b/cmd/update.go @@ -0,0 +1,24 @@ +package cmd + +import ( + "github.com/cisco-open/grabit/internal" + "github.com/spf13/cobra" +) + +var updateCmd = &cobra.Command{ + Use: "update [URL]", + Short: "Update a resource", + Args: cobra.ExactArgs(1), + RunE: func(cmd *cobra.Command, args []string) error { + lockFile, _ := cmd.Flags().GetString("lock-file") + lock, err := internal.NewLock(lockFile, false) + if err != nil { + return err + } + return lock.UpdateResource(args[0]) + }, +} + +func AddUpdate(cmd *cobra.Command) { + cmd.AddCommand(updateCmd) +} diff --git a/cmd/verify.go b/cmd/verify.go new file mode 100644 index 0000000..7301983 --- /dev/null +++ b/cmd/verify.go @@ -0,0 +1,23 @@ +package cmd + +import ( + "github.com/cisco-open/grabit/internal" + "github.com/spf13/cobra" +) + +var verifyCmd = &cobra.Command{ + Use: "verify", + Short: "Verify the integrity of downloaded resources", + RunE: func(cmd *cobra.Command, args []string) error { + lockFile, _ := cmd.Flags().GetString("lock-file") + lock, err := internal.NewLock(lockFile, false) + if err != nil { + return err + } + return lock.VerifyIntegrity() + }, +} + +func AddVerify(cmd *cobra.Command) { + cmd.AddCommand(verifyCmd) +} diff --git a/internal/lock.go b/internal/lock.go index ae6707b..d503b13 100644 --- a/internal/lock.go +++ b/internal/lock.go @@ -5,12 +5,10 @@ package internal import ( "bufio" - "context" "errors" "fmt" "os" "strconv" - toml "github.com/pelletier/go-toml/v2" ) @@ -26,6 +24,12 @@ type config struct { Resource []Resource } +type resource struct { + Urls []string + Integrity string + Tags []string +} + func NewLock(path string, newOk bool) (*Lock, error) { _, error := os.Stat(path) if os.IsNotExist(error) { @@ -49,13 +53,13 @@ func NewLock(path string, newOk bool) (*Lock, error) { return &Lock{path: path, conf: conf}, nil } -func (l *Lock) AddResource(paths []string, algo string, tags []string, filename string) error { +func (l *Lock) AddResource(paths []string, algo string, tags []string, filename string, dynamic bool) error { for _, u := range paths { if l.Contains(u) { return fmt.Errorf("resource '%s' is already present", u) } } - r, err := NewResourceFromUrl(paths, algo, tags, filename) + r, err := NewResourceFromUrl(paths, algo, tags, filename, dynamic) if err != nil { return err } @@ -89,89 +93,110 @@ func strToFileMode(perm string) (os.FileMode, error) { // Download gets all the resources in this lock file and moves them to // the destination directory. -func (l *Lock) Download(dir string, tags []string, notags []string, perm string) error { - if stat, err := os.Stat(dir); err != nil || !stat.IsDir() { - return fmt.Errorf("'%s' is not a directory", dir) - } - mode, err := strToFileMode(perm) - if err != nil { - return fmt.Errorf("'%s' is not a valid permission definition", perm) - } +func (l *Lock) Download(dir string, tags []string, notags []string, perm string, ctx context.Context) error { + if stat, err := os.Stat(dir); err != nil || !stat.IsDir() { + return fmt.Errorf("'%s' is not a directory", dir) + } + mode, err := strToFileMode(perm) + if err != nil { + return fmt.Errorf("'%s' is not a valid permission definition", perm) + } + + filteredResources := l.filterResources(tags, notags) + + total := len(filteredResources) + if total == 0 { + return fmt.Errorf("nothing to download") + } + + errs := make([]error, 0) + for _, r := range filteredResources { + err := r.Download(dir, mode, ctx) + if err != nil { + errs = append(errs, fmt.Errorf("failed to download %s: %w", r.Urls[0], err)) + } + } + + if len(errs) > 0 { + return errors.Join(errs...) + } + return nil +} - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - // Filter in the resources that have all the required tags. - tagFilteredResources := []Resource{} + +func (l *Lock) filterResources(tags []string, notags []string) []Resource { + tagFilteredResources := l.conf.Resource if len(tags) > 0 { + tagFilteredResources = []Resource{} for _, r := range l.conf.Resource { - hasAllTags := true - for _, tag := range tags { - hasTag := false - for _, rtag := range r.Tags { - if tag == rtag { - hasTag = true - break - } - } - if !hasTag { - hasAllTags = false - break - } - } - if hasAllTags { + if r.hasAllTags(tags) { tagFilteredResources = append(tagFilteredResources, r) } } - } else { - tagFilteredResources = l.conf.Resource } - // Filter out the resources that have any 'notag' tag. - filteredResources := []Resource{} + + filteredResources := tagFilteredResources if len(notags) > 0 { + filteredResources = []Resource{} for _, r := range tagFilteredResources { - hasTag := false - for _, notag := range notags { - for _, rtag := range r.Tags { - if notag == rtag { - hasTag = true - } - } - } - if !hasTag { + if !r.hasAnyTag(notags) { filteredResources = append(filteredResources, r) } } - } else { - filteredResources = tagFilteredResources - } - - total := len(filteredResources) - if total == 0 { - return fmt.Errorf("nothing to download") - } - errorCh := make(chan error, total) - for _, r := range filteredResources { - resource := r - go func() { - err := resource.Download(dir, mode, ctx) - errorCh <- err - }() - } - done := 0 - errs := []error{} - for range total { - err = <-errorCh - if err != nil { - errs = append(errs, err) - } else { - done += 1 + } + + return filteredResources +} + +func (r *Resource) hasAllTags(tags []string) bool { + for _, tag := range tags { + if !r.hasTag(tag) { + return false + } + } + return true +} + +func (r *Resource) hasAnyTag(tags []string) bool { + for _, tag := range tags { + if r.hasTag(tag) { + return true } } - if done == total { - return nil + return false +} + +func (r *Resource) hasTag(tag string) bool { + for _, rtag := range r.Tags { + if tag == rtag { + return true + } + } + return false +} + +func (l *Lock) UpdateResource(url string) error { + for i, r := range l.conf.Resource { + if r.Contains(url) { + newResource, err := NewResourceFromUrl(r.Urls, r.Integrity, r.Tags, r.Filename, r.Dynamic) + if err != nil { + return err + } + l.conf.Resource[i] = *newResource + return l.Save() + } } - if len(errs) > 0 { - return errors.Join(errs...) + return fmt.Errorf("resource with URL '%s' not found", url) +} + +func (l *Lock) VerifyIntegrity() error { + for _, r := range l.conf.Resource { + for _, url := range r.Urls { + err := checkIntegrityFromUrl(url, r.Integrity) + if err != nil { + return fmt.Errorf("integrity check failed for %s: %w", url, err) + } + } } return nil } @@ -207,4 +232,5 @@ func (l *Lock) Contains(url string) bool { } } return false + } diff --git a/internal/lock_test.go b/internal/lock_test.go index eb347ed..ec2ebf4 100644 --- a/internal/lock_test.go +++ b/internal/lock_test.go @@ -9,6 +9,7 @@ import ( "os" "path/filepath" "testing" + "time" "github.com/cisco-open/grabit/test" "github.com/stretchr/testify/assert" @@ -51,7 +52,7 @@ func TestLockManipulations(t *testing.T) { port, server := test.HttpHandler(handler) defer server.Close() resource := fmt.Sprintf("http://localhost:%d/test2.html", port) - err = lock.AddResource([]string{resource}, "sha512", []string{}, "") + err = lock.AddResource([]string{resource}, "sha512", []string{}, "", false) assert.Nil(t, err) assert.Equal(t, 2, len(lock.conf.Resource)) err = lock.Save() @@ -63,12 +64,12 @@ func TestLockManipulations(t *testing.T) { func TestDuplicateResource(t *testing.T) { url := "http://localhost:123456/test.html" path := test.TmpFile(t, fmt.Sprintf(` - [[Resource]] - Urls = ['%s'] - Integrity = 'sha256-asdasdasd'`, url)) + [[Resource]] + Urls = ['%s'] + Integrity = 'sha256-asdasdasd'`, url)) lock, err := NewLock(path, false) assert.Nil(t, err) - err = lock.AddResource([]string{url}, "sha512", []string{}, "") + err = lock.AddResource([]string{url}, "sha512", []string{}, "", false) assert.NotNil(t, err) assert.Contains(t, err.Error(), "already present") } @@ -115,7 +116,7 @@ func TestDownload(t *testing.T) { lock, err := NewLock(path, false) assert.Nil(t, err) dir := test.TmpDir(t) - err = lock.Download(dir, []string{}, []string{}, perm) + err = lock.Download(dir, []string{}, []string{}, perm, downloader.NewDownloader(10*time.Second)) if err != nil { t.Fatal(err) } @@ -132,3 +133,4 @@ func TestDownload(t *testing.T) { } assert.Equal(t, stats.Mode().Perm().String(), strPerm) } + diff --git a/internal/resource.go b/internal/resource.go index 4e5f05f..6072d3b 100644 --- a/internal/resource.go +++ b/internal/resource.go @@ -8,6 +8,8 @@ import ( "crypto/sha256" "encoding/hex" "fmt" + "io" + "net/http" "net/url" "os" "path" @@ -24,12 +26,14 @@ type Resource struct { Integrity string Tags []string `toml:",omitempty"` Filename string `toml:",omitempty"` + Dynamic bool `toml:",omitempty"` } -func NewResourceFromUrl(urls []string, algo string, tags []string, filename string) (*Resource, error) { +func NewResourceFromUrl(urls []string, algo string, tags []string, filename string, dynamic bool) (*Resource, error) { if len(urls) < 1 { return nil, fmt.Errorf("empty url list") } + url := urls[0] ctx := context.Background() path, err := GetUrltoTempFile(url, ctx) @@ -63,6 +67,21 @@ func getUrl(u string, fileName string, ctx context.Context) (string, error) { return fileName, nil } +func checkIntegrityFromUrl(url string, expectedIntegrity string) error { + tempFile, err := GetUrltoTempFile(url, context.Background()) + if err != nil { + return err + } + defer os.Remove(tempFile) + + algo, err := getAlgoFromIntegrity(expectedIntegrity) + if err != nil { + return err + } + + return checkIntegrityFromFile(tempFile, algo, expectedIntegrity, url) +} + // GetUrlToDir downloads the given resource to the given directory and returns the path to it. func GetUrlToDir(u string, targetDir string, ctx context.Context) (string, error) { // create temporary name in the target directory. @@ -83,55 +102,27 @@ func GetUrltoTempFile(u string, ctx context.Context) (string, error) { } func (l *Resource) Download(dir string, mode os.FileMode, ctx context.Context) error { - ok := false - algo, err := getAlgoFromIntegrity(l.Integrity) - if err != nil { - return err - } - var downloadError error = nil for _, u := range l.Urls { - // Download file in the target directory so that the call to - // os.Rename is atomic. - lpath, err := GetUrlToDir(u, dir, ctx) + err := l.DownloadFile(u, dir) if err != nil { - downloadError = err - break - } - err = checkIntegrityFromFile(lpath, algo, l.Integrity, u) - if err != nil { - return err + continue } - localName := "" - if l.Filename != "" { - localName = l.Filename - } else { + localName := l.Filename + if localName == "" { localName = path.Base(u) } resPath := filepath.Join(dir, localName) - err = os.Rename(lpath, resPath) - if err != nil { - return err - } + if mode != NoFileMode { err = os.Chmod(resPath, mode.Perm()) if err != nil { return err } } - ok = true - } - if !ok { - if err == nil { - if downloadError != nil { - return downloadError - } else { - panic("no error but no file downloaded") - } - } - return err + return nil } - return nil + return fmt.Errorf("failed to download resource from any URL") } func (l *Resource) Contains(url string) bool { @@ -142,3 +133,57 @@ func (l *Resource) Contains(url string) bool { } return false } +func calculateFileHash(filePath string) (string, error) { + file, err := os.Open(filePath) + if err != nil { + return "", err + } + defer file.Close() + + hash := sha256.New() + if _, err := io.Copy(hash, file); err != nil { + return "", err + } + + return hex.EncodeToString(hash.Sum(nil)), nil +} + +func (l *Resource) DownloadFile(url, targetDir string) error { + fileName := filepath.Base(url) + targetPath := filepath.Join(targetDir, fileName) + + if _, err := os.Stat(targetPath); err == nil { + fileHash, err := calculateFileHash(targetPath) + if err == nil && fileHash == l.Integrity { + log.Debug().Str("File", fileName).Msg("File already exists with correct hash. Skipping download.") + return nil + } + } + + resp, err := http.Get(url) + if err != nil { + return err + } + defer resp.Body.Close() + + out, err := os.Create(targetPath) + if err != nil { + return err + } + defer out.Close() + + _, err = io.Copy(out, resp.Body) + if err != nil { + return err + } + + downloadedHash, err := calculateFileHash(targetPath) + if err != nil { + return err + } + if downloadedHash != l.Integrity { + return fmt.Errorf("hash mismatch: expected %s, got %s", l.Integrity, downloadedHash) + } + + return nil +} diff --git a/internal/resource_test.go b/internal/resource_test.go index 80a11e9..84969c2 100644 --- a/internal/resource_test.go +++ b/internal/resource_test.go @@ -4,12 +4,13 @@ package internal import ( + "context" "fmt" - "net/http" - "testing" - "github.com/cisco-open/grabit/test" "github.com/stretchr/testify/assert" + "net/http" + "testing" + "time" ) func TestNewResourceFromUrl(t *testing.T) { @@ -41,7 +42,7 @@ func TestNewResourceFromUrl(t *testing.T) { } for _, data := range tests { - resource, err := NewResourceFromUrl(data.urls, algo, []string{}, "") + resource, err := NewResourceFromUrl(data.urls, "sha256", []string{}, "", false) assert.Equal(t, data.valid, err == nil) if err != nil { assert.Contains(t, err.Error(), data.errorContains) @@ -50,3 +51,24 @@ func TestNewResourceFromUrl(t *testing.T) { } } } +func TestDynamicResourceDownload(t *testing.T) { + handler := func(w http.ResponseWriter, r *http.Request) { + w.Write([]byte(time.Now().String())) + } + port, server := test.HttpHandler(handler) + defer server.Close() + + url := fmt.Sprintf("http://localhost:%d/dynamic", port) + resource := &Resource{ + Urls: []string{url}, + Dynamic: true, + } + + dir := t.TempDir() + err := resource.Download(dir, 0644, context.Background()) + assert.NoError(t, err) + + // Download again to ensure it doesn't fail due to content change + err = resource.Download(dir, 0644, context.Background()) + assert.NoError(t, err) +}