// Command nats-upload uploads a directory of binaries to a NATS JetStream
// object store, optionally prunes old versions per binary/architecture, and
// publishes an update notification on a NATS subject.
package main

import (
	"context"
	"errors"
	"fmt"
	"log"
	"os"
	"path/filepath"
	"sort"
	"strings"
	"time"

	"github.com/nats-io/nats.go"
	"github.com/nats-io/nats.go/jetstream"
	"github.com/spf13/cobra"
	"github.com/spf13/pflag"
	"github.com/spf13/viper"
	"golang.org/x/mod/semver"
)

// Config holds the settings decoded from viper (flags and INPUT_* environment
// variables). The mapstructure tags match the flag names bound in init.
type Config struct {
	NatsURL     string `mapstructure:"nats"`        // NATS server URL
	BucketName  string `mapstructure:"bucket"`      // object store bucket name
	Directory   string `mapstructure:"dir"`         // directory walked for files to upload
	Prefix      string `mapstructure:"prefix"`      // prefix stripped from each relative path
	BinaryName  string `mapstructure:"binary"`      // binary whose versions are cleaned up
	NotifyTopic string `mapstructure:"notify"`      // subject for the update notification
	SkipNotify  bool   `mapstructure:"skip-notify"` // suppress the update notification
	Cleanup     int    `mapstructure:"cleanup"`     // number of versions to keep; <= 0 disables cleanup
	CleanupAll  bool   `mapstructure:"cleanup-all"` // clean every binary, not only BinaryName
}

// loadConfig decodes the viper state into a Config and applies the
// --clean-all alias, which is declared as a flag but not bound to viper.
func loadConfig(cmd *cobra.Command) (*Config, error) {
	var cfg Config
	if err := viper.Unmarshal(&cfg); err != nil {
		return nil, fmt.Errorf("failed to unmarshal config: %w", err)
	}

	// --clean-all is documented as an alias for --cleanup-all; honour it here
	// so that passing either flag enables cleanup of all binaries.
	cleanAll, err := cmd.Flags().GetBool("clean-all")
	if err != nil {
		return nil, fmt.Errorf("failed to read clean-all flag: %w", err)
	}
	if cleanAll {
		cfg.CleanupAll = true
	}

	return &cfg, nil
}

// rootCmd uploads the configured directory and then cleans up old versions.
var rootCmd = &cobra.Command{
	Use:   "nats-upload",
	Short: "Upload binaries to NATS object store and cleanup old versions",
	RunE: func(cmd *cobra.Command, args []string) error {
		cfg, err := loadConfig(cmd)
		if err != nil {
			return err
		}

		// Nothing to do when there is neither a directory to upload nor a
		// positive number of versions to keep.
		if cfg.Directory == "" && cfg.Cleanup <= 0 {
			return errors.New("directory path is required or cleanup must be enabled")
		}

		return runUploadAndCleanup(cmd.Context(), cfg)
	},
}

// cleanCmd only cleans up old versions; it never uploads or notifies.
var cleanCmd = &cobra.Command{
	Use:   "clean",
	Short: "Cleanup old versions in NATS object store",
	RunE: func(cmd *cobra.Command, args []string) error {
		cfg, err := loadConfig(cmd)
		if err != nil {
			return err
		}

		// A negative count would otherwise reach the slicing in
		// cleanupOldVersions and panic, so reject it together with zero.
		if cfg.Cleanup <= 0 {
			return errors.New("cleanup count must be greater than 0")
		}

		return runCleanupOnly(cmd.Context(), cfg)
	},
}

// init declares the command-line flags. Persistent flags are shared with the
// clean subcommand; the remaining flags apply to the root command only.
func init() {
	cobra.OnInitialize(initConfig)

	persistent := rootCmd.PersistentFlags()
	persistent.String("nats", "nats://localhost:4222", "NATS server URL")
	persistent.String("bucket", "binaries", "Object store bucket name")
	persistent.String("binary", "", "Binary name (defaults to first binary found)")
	persistent.Int("cleanup", 2, "Keep only N most recent versions (0 disables cleanup)")
	persistent.Bool("cleanup-all", false, "Cleanup all binaries, not just current one")
	persistent.Bool("clean-all", false, "Alias for --cleanup-all")

	local := rootCmd.Flags()
	local.String("dir", "upload", "Directory containing binaries to upload")
	local.String("prefix", "", "Prefix to strip from paths (like 'upload/')")
	local.String("notify", "binaries.update", "NATS topic to publish update notification")
	local.Bool("skip-notify", false, "Skip publishing update notification")
}

// bindPFlag binds the viper key to a flag in fs and terminates the program if
// the binding fails. The flag name defaults to key; the first element of
// flagNames overrides it.
func bindPFlag(fs *pflag.FlagSet, key string, flagNames ...string) {
	name := key
	if len(flagNames) > 0 {
		name = flagNames[0]
	}

	if err := viper.BindPFlag(key, fs.Lookup(name)); err != nil {
		log.Fatalf("error binding %s flag: %v", key, err)
	}
}

// init binds the declared flags to viper keys and registers the clean
// subcommand. It must stay after the init that declares the flags, because
// binding looks the flags up by name.
func init() {
	rootPersistentFlags := rootCmd.PersistentFlags()
	for _, name := range []string{"nats", "bucket", "binary", "cleanup", "cleanup-all"} {
		bindPFlag(rootPersistentFlags, name)
	}

	rootFlags := rootCmd.Flags()
	for _, name := range []string{"dir", "prefix", "notify", "skip-notify"} {
		bindPFlag(rootFlags, name)
	}

	rootCmd.AddCommand(cleanCmd)
}

// initConfig enables INPUT_-prefixed environment variables and registers
// alternative key names for several settings.
func initConfig() {
	viper.SetEnvPrefix("INPUT")
	viper.AutomaticEnv()

	viper.RegisterAlias("nats_url", "nats")
	viper.RegisterAlias("source", "dir")
	viper.RegisterAlias("strip_prefix", "prefix")
	viper.RegisterAlias("notify_topic", "notify")
	viper.RegisterAlias("clean_all", "cleanup-all")
}

// NATSClient bundles the NATS connection with the JetStream handle and the
// object store opened on it.
type NATSClient struct {
	Conn  *nats.Conn
	JS    jetstream.JetStream
	Store jetstream.ObjectStore
}

// getNATSConnection connects to cfg.NatsURL and opens the object store named
// cfg.BucketName, creating the bucket when the lookup fails. The connection is
// closed before returning on any error; on success the caller must close
// Conn.
func getNATSConnection(ctx context.Context, cfg *Config) (*NATSClient, error) {
	nc, err := nats.Connect(cfg.NatsURL)
	if err != nil {
		return nil, fmt.Errorf("failed to connect to NATS: %w", err)
	}

	js, err := jetstream.New(nc)
	if err != nil {
		nc.Close()
		return nil, fmt.Errorf("failed to create JetStream context: %w", err)
	}

	store, err := js.ObjectStore(ctx, cfg.BucketName)
	if err != nil {
		// Any lookup failure falls through to a create attempt.
		store, err = js.CreateObjectStore(ctx, jetstream.ObjectStoreConfig{
			Bucket:      cfg.BucketName,
			Description: "Binary storage for self-update",
		})
		if err != nil {
			nc.Close()
			return nil, fmt.Errorf("failed to get/create object store: %w", err)
		}
		log.Printf("Created object store: %s", cfg.BucketName)
	}

	return &NATSClient{
		Conn:  nc,
		JS:    js,
		Store: store,
	}, nil
}

// uploadDirectory walks cfg.Directory and stores every non-directory entry in
// store. The object key is the slash-separated path relative to cfg.Directory
// with cfg.Prefix stripped. When cfg.BinaryName is empty it is set to the
// first path segment of the first key that has at least two segments.
func uploadDirectory(ctx context.Context, store jetstream.ObjectStore, cfg *Config) error {
	// Normalize the prefix once so it is compared against slash-separated
	// keys regardless of the host path separator.
	prefix := filepath.ToSlash(cfg.Prefix)

	return filepath.Walk(cfg.Directory, func(path string, info os.FileInfo, err error) error {
		if err != nil {
			return err
		}
		if info.IsDir() {
			return nil
		}

		data, err := os.ReadFile(path)
		if err != nil {
			return fmt.Errorf("failed to read %s: %w", path, err)
		}

		relPath, err := filepath.Rel(cfg.Directory, path)
		if err != nil {
			return fmt.Errorf("failed to get relative path: %w", err)
		}

		// Convert to slashes before trimming so the prefix matches on every OS.
		objectKey := filepath.ToSlash(relPath)
		if prefix != "" {
			objectKey = strings.TrimPrefix(objectKey, prefix)
		}

		if cfg.BinaryName == "" {
			if name, _, found := strings.Cut(objectKey, "/"); found {
				cfg.BinaryName = name
			}
		}

		log.Printf("Uploading %s as %s (%d bytes)", path, objectKey, len(data))
		if _, err := store.PutBytes(ctx, objectKey, data); err != nil {
			return fmt.Errorf("failed to upload %s: %w", path, err)
		}
		log.Printf("✓ Uploaded %s", objectKey)

		return nil
	})
}

// runUploadAndCleanup uploads cfg.Directory (when set), cleans up old versions
// (when cfg.Cleanup is positive) and publishes the update notification unless
// it is skipped or no topic is configured.
func runUploadAndCleanup(ctx context.Context, cfg *Config) error {
	client, err := getNATSConnection(ctx, cfg)
	if err != nil {
		return err
	}
	defer client.Conn.Close()

	if cfg.Directory != "" {
		if err := uploadDirectory(ctx, client.Store, cfg); err != nil {
			return fmt.Errorf("failed to upload files: %w", err)
		}
		log.Printf("Successfully uploaded all files from %s to NATS object store '%s'", cfg.Directory, cfg.BucketName)
	}

	if cfg.Cleanup > 0 {
		log.Printf("Cleaning up old versions, keeping %d most recent", cfg.Cleanup)
		if err := cleanupOldVersions(ctx, client.Store, cfg.BinaryName, cfg.Cleanup, cfg.CleanupAll); err != nil {
			return fmt.Errorf("failed to cleanup old versions: %w", err)
		}
	}

	if !cfg.SkipNotify && cfg.NotifyTopic != "" {
		log.Printf("Publishing update notification to topic: %s", cfg.NotifyTopic)
		message := fmt.Sprintf("binaries updated in %s", cfg.BucketName)
		if err := client.Conn.Publish(cfg.NotifyTopic, []byte(message)); err != nil {
			return fmt.Errorf("failed to publish notification: %w", err)
		}

		// Flush to ensure the message is sent before the connection closes.
		if err := client.Conn.Flush(); err != nil {
			return fmt.Errorf("failed to flush notification: %w", err)
		}
		log.Printf("✓ Published update notification")
	}

	return nil
}

// runCleanupOnly connects to NATS and cleans up old versions without
// uploading or notifying.
func runCleanupOnly(ctx context.Context, cfg *Config) error {
	client, err := getNATSConnection(ctx, cfg)
	if err != nil {
		return err
	}
	defer client.Conn.Close()

	log.Printf("Cleaning up old versions, keeping %d most recent", cfg.Cleanup)
	if err := cleanupOldVersions(ctx, client.Store, cfg.BinaryName, cfg.Cleanup, cfg.CleanupAll); err != nil {
		return fmt.Errorf("failed to cleanup old versions: %w", err)
	}

	return nil
}

// semverKey returns the last path element of an object name with a leading
// "v" added when missing, which is the form semver.Compare expects.
func semverKey(objectName string) string {
	version := filepath.Base(objectName)
	if !strings.HasPrefix(version, "v") {
		version = "v" + version
	}
	return version
}

// cleanupOldVersions deletes all but the keepCount highest versions for each
// binary/arch pair in store.
//
// Object names are expected to look like binary/arch/version; names with fewer
// than three slash-separated parts are ignored. Unless cleanAll is set, and
// provided currentBinary is non-empty, only objects of currentBinary are
// considered. Versions are ordered with semver.Compare on the last path
// element. A keepCount that is not positive is rejected, since it would either
// delete every version or slice out of range. Objects that are already gone
// when deleted are not treated as errors.
func cleanupOldVersions(ctx context.Context, store jetstream.ObjectStore, currentBinary string, keepCount int, cleanAll bool) error {
	if keepCount <= 0 {
		return fmt.Errorf("keep count must be greater than 0, got %d", keepCount)
	}

	objects, err := store.List(ctx)
	if err != nil {
		return fmt.Errorf("failed to list objects: %w", err)
	}

	// Group objects by their "binary/arch" path.
	versionsByPath := make(map[string][]*jetstream.ObjectInfo)
	for _, obj := range objects {
		parts := strings.Split(obj.Name, "/")
		if len(parts) < 3 {
			// Not a version path, skip.
			continue
		}

		binName, arch := parts[0], parts[1]

		// If not cleaning all and this isn't the current binary, skip.
		if !cleanAll && currentBinary != "" && binName != currentBinary {
			continue
		}

		pathKey := binName + "/" + arch
		versionsByPath[pathKey] = append(versionsByPath[pathKey], obj)
	}

	// For each binary/arch combination, keep only the most recent N versions.
	for pathKey, versions := range versionsByPath {
		if len(versions) <= keepCount {
			log.Printf("Path %s has %d versions, keeping all", pathKey, len(versions))
			continue
		}

		// Sort by semantic version, newest first.
		sort.Slice(versions, func(i, j int) bool {
			return semver.Compare(semverKey(versions[i].Name), semverKey(versions[j].Name)) > 0
		})

		// Everything after the first keepCount entries is old.
		toDelete := versions[keepCount:]
		log.Printf("Path %s has %d versions, deleting %d old versions", pathKey, len(versions), len(toDelete))

		for _, obj := range toDelete {
			version := filepath.Base(obj.Name)
			log.Printf("Deleting old version: %s (version: %s)", obj.Name, version)

			err := store.Delete(ctx, obj.Name)
			if err != nil && !errors.Is(err, jetstream.ErrObjectNotFound) {
				return fmt.Errorf("failed to delete %s: %w", obj.Name, err)
			}
			log.Printf("✓ Deleted %s", obj.Name)
		}
	}

	return nil
}

// execute runs the root command under a 15-minute timeout and returns the
// process exit code. Keeping it separate from main lets the deferred cancel
// run before os.Exit is called.
func execute() int {
	ctx, cancel := context.WithTimeout(context.Background(), 15*time.Minute)
	defer cancel()

	if err := rootCmd.ExecuteContext(ctx); err != nil {
		_, _ = fmt.Fprintln(os.Stderr, err)
		return 1
	}
	return 0
}

func main() {
	os.Exit(execute())
}