Skip to content

Commit 7bb4ca8

Browse files
committed
allow config struct to be watched for changes on disk
1 parent 8e5a256 commit 7bb4ca8

File tree

3 files changed

+283
-24
lines changed

3 files changed

+283
-24
lines changed

internal/cmdutil/preparers/preparers.go

Lines changed: 2 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,7 @@ package preparers
22

33
import (
44
"context"
5-
"errors"
65
"fmt"
7-
"io/fs"
86
"net/http"
97
"os"
108
"path/filepath"
@@ -33,20 +31,11 @@ type Preparer func(context.Context) (context.Context, error)
3331
func LoadConfig(ctx context.Context) (context.Context, error) {
3432
logger := logger.FromContext(ctx)
3533

36-
cfg := config.New()
37-
38-
// Apply config from the config file, if it exists
39-
path := filepath.Join(state.ConfigDirectory(ctx), config.FileName)
40-
if err := cfg.ApplyFile(path); err != nil && !errors.Is(err, fs.ErrNotExist) {
34+
cfg, err := config.Load(ctx, filepath.Join(state.ConfigDirectory(ctx), config.FileName))
35+
if err != nil {
4136
return nil, err
4237
}
4338

44-
// Apply config from the environment, overriding anything from the file
45-
cfg.ApplyEnv()
46-
47-
// Finally, apply command line options, overriding any previous setting
48-
cfg.ApplyFlags(flagctx.FromContext(ctx))
49-
5039
logger.Debug("config initialized.")
5140

5241
return config.NewContext(ctx, cfg), nil

internal/config/config.go

Lines changed: 148 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,20 @@
11
package config
22

33
import (
4+
"context"
5+
"errors"
6+
"io/fs"
47
"sync"
8+
"time"
59

10+
"github.com/fsnotify/fsnotify"
611
"github.com/spf13/pflag"
712

813
"github.com/superfly/fly-go/tokens"
914
"github.com/superfly/flyctl/internal/env"
15+
"github.com/superfly/flyctl/internal/flag/flagctx"
1016
"github.com/superfly/flyctl/internal/flag/flagnames"
17+
"github.com/superfly/flyctl/internal/task"
1118
)
1219

1320
const (
@@ -47,7 +54,12 @@ const (
4754
//
4855
// Instances of Config are safe for concurrent use.
4956
type Config struct {
50-
mu sync.RWMutex
57+
mu sync.RWMutex
58+
path string
59+
60+
watchOnce sync.Once
61+
watchErr error
62+
subs map[chan *Config]struct{}
5163

5264
// APIBaseURL denotes the base URL of the API.
5365
APIBaseURL string
@@ -93,22 +105,34 @@ type Config struct {
93105
MetricsToken string
94106
}
95107

96-
// New returns a new instance of Config populated with default values.
97-
func New() *Config {
98-
return &Config{
108+
func Load(ctx context.Context, path string) (*Config, error) {
109+
cfg := &Config{
99110
APIBaseURL: defaultAPIBaseURL,
100111
FlapsBaseURL: defaultFlapsBaseURL,
101112
RegistryHost: defaultRegistryHost,
102113
MetricsBaseURL: defaultMetricsBaseURL,
103114
Tokens: new(tokens.Tokens),
104115
}
116+
117+
// Apply config from the config file, if it exists
118+
if err := cfg.applyFile(path); err != nil && !errors.Is(err, fs.ErrNotExist) {
119+
return nil, err
120+
}
121+
122+
// Apply config from the environment, overriding anything from the file
123+
cfg.applyEnv()
124+
125+
// Finally, apply command line options, overriding any previous setting
126+
cfg.applyFlags(flagctx.FromContext(ctx))
127+
128+
return cfg, nil
105129
}
106130

107-
// ApplyEnv sets the properties of cfg which may be set via environment
131+
// applyEnv sets the properties of cfg which may be set via environment
108132
// variables to the values these variables contain.
109133
//
110-
// ApplyEnv does not change the dirty state of config.
111-
func (cfg *Config) ApplyEnv() {
134+
// applyEnv does not change the dirty state of config.
135+
func (cfg *Config) applyEnv() {
112136
cfg.mu.Lock()
113137
defer cfg.mu.Unlock()
114138

@@ -131,12 +155,14 @@ func (cfg *Config) ApplyEnv() {
131155
cfg.SendMetrics = env.IsTruthy(SendMetricsEnvKey) || cfg.SendMetrics
132156
}
133157

134-
// ApplyFile sets the properties of cfg which may be set via configuration file
158+
// applyFile sets the properties of cfg which may be set via configuration file
135159
// to the values the file at the given path contains.
136-
func (cfg *Config) ApplyFile(path string) (err error) {
160+
func (cfg *Config) applyFile(path string) (err error) {
137161
cfg.mu.Lock()
138162
defer cfg.mu.Unlock()
139163

164+
cfg.path = path
165+
140166
var w struct {
141167
AccessToken string `yaml:"access_token"`
142168
MetricsToken string `yaml:"metrics_token"`
@@ -158,9 +184,9 @@ func (cfg *Config) ApplyFile(path string) (err error) {
158184
return
159185
}
160186

161-
// ApplyFlags sets the properties of cfg which may be set via command line flags
187+
// applyFlags sets the properties of cfg which may be set via command line flags
162188
// to the values the flags of the given FlagSet may contain.
163-
func (cfg *Config) ApplyFlags(fs *pflag.FlagSet) {
189+
func (cfg *Config) applyFlags(fs *pflag.FlagSet) {
164190
cfg.mu.Lock()
165191
defer cfg.mu.Unlock()
166192

@@ -188,6 +214,117 @@ func (cfg *Config) MetricsBaseURLIsProduction() bool {
188214
return cfg.MetricsBaseURL == defaultMetricsBaseURL
189215
}
190216

217+
func (cfg *Config) Watch(ctx context.Context) (chan *Config, error) {
218+
cfg.watchOnce.Do(func() {
219+
watch, err := fsnotify.NewWatcher()
220+
if err != nil {
221+
cfg.watchErr = err
222+
return
223+
}
224+
225+
if err := watch.Add(cfg.path); err != nil {
226+
cfg.watchErr = err
227+
return
228+
}
229+
230+
cfg.subs = make(map[chan *Config]struct{})
231+
232+
task.FromContext(ctx).Run(func(ctx context.Context) {
233+
ctx, cancel := context.WithCancel(ctx)
234+
defer cancel()
235+
236+
cleanupDone := make(chan struct{})
237+
defer func() { <-cleanupDone }()
238+
239+
go func() {
240+
defer close(cleanupDone)
241+
242+
<-ctx.Done()
243+
244+
cfg.mu.Lock()
245+
defer cfg.mu.Unlock()
246+
247+
cfg.watchErr = errors.Join(cfg.watchErr, ctx.Err(), watch.Close())
248+
249+
for sub := range cfg.subs {
250+
close(sub)
251+
}
252+
cfg.subs = nil
253+
}()
254+
255+
for {
256+
select {
257+
case e, open := <-watch.Events:
258+
if !open {
259+
return
260+
}
261+
262+
if !e.Has(fsnotify.Write) {
263+
continue
264+
}
265+
266+
go cfg.notifySubs(ctx)
267+
case err := <-watch.Errors:
268+
cfg.mu.Lock()
269+
defer cfg.mu.Unlock()
270+
271+
cfg.watchErr = errors.Join(cfg.watchErr, err)
272+
273+
return
274+
case <-ctx.Done():
275+
return
276+
}
277+
}
278+
})
279+
})
280+
281+
cfg.mu.Lock()
282+
defer cfg.mu.Unlock()
283+
284+
if cfg.watchErr != nil {
285+
return nil, cfg.watchErr
286+
}
287+
288+
sub := make(chan *Config)
289+
cfg.subs[sub] = struct{}{}
290+
291+
return sub, nil
292+
}
293+
294+
func (cfg *Config) Unwatch(sub chan *Config) {
295+
cfg.mu.Lock()
296+
defer cfg.mu.Unlock()
297+
298+
if cfg.subs != nil {
299+
delete(cfg.subs, sub)
300+
close(sub)
301+
}
302+
}
303+
304+
func (cfg *Config) notifySubs(ctx context.Context) {
305+
newCfg, err := Load(ctx, cfg.path)
306+
if err != nil {
307+
return
308+
}
309+
310+
cfg.mu.RLock()
311+
defer cfg.mu.RUnlock()
312+
313+
// just in case we have a slow subscriber
314+
timer := time.NewTimer(100 * time.Millisecond)
315+
defer timer.Stop()
316+
317+
for sub := range cfg.subs {
318+
select {
319+
case sub <- newCfg:
320+
case <-timer.C:
321+
return
322+
case <-ctx.Done():
323+
return
324+
}
325+
}
326+
}
327+
191328
func applyStringFlags(fs *pflag.FlagSet, flags map[string]*string) {
192329
for name, dst := range flags {
193330
if !fs.Changed(name) {

internal/config/config_test.go

Lines changed: 133 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,133 @@
1+
package config
2+
3+
import (
4+
"context"
5+
"errors"
6+
"io"
7+
"os"
8+
"path"
9+
"sync"
10+
"testing"
11+
"time"
12+
13+
"github.com/spf13/pflag"
14+
"github.com/stretchr/testify/assert"
15+
"github.com/stretchr/testify/require"
16+
"github.com/superfly/flyctl/flyctl"
17+
"github.com/superfly/flyctl/internal/flag/flagctx"
18+
"github.com/superfly/flyctl/internal/logger"
19+
"github.com/superfly/flyctl/internal/task"
20+
)
21+
22+
func TestConfigWatch(t *testing.T) {
23+
cfgDirWas, cfgDirWasSet := os.LookupEnv("FLY_CONFIG_DIR")
24+
os.Setenv("FLY_CONFIG_DIR", t.TempDir())
25+
flyctl.InitConfig()
26+
t.Cleanup(func() {
27+
if cfgDirWasSet {
28+
os.Setenv("FLY_CONFIG_DIR", cfgDirWas)
29+
} else {
30+
os.Unsetenv("FLY_CONFIG_DIR")
31+
}
32+
})
33+
34+
ctx, cancel := context.WithCancel(context.Background())
35+
defer cancel()
36+
37+
ctx = logger.NewContext(ctx, logger.New(io.Discard, logger.Error, false))
38+
ctx = flagctx.NewContext(ctx, new(pflag.FlagSet))
39+
40+
tm := task.New()
41+
tm.Start(ctx)
42+
ctx = task.WithContext(ctx, tm)
43+
44+
path := path.Join(t.TempDir(), "config.yml")
45+
46+
require.NoError(t, os.WriteFile(path, []byte(`access_token: fo1_foo`), 0644))
47+
cfg, err := Load(ctx, path)
48+
require.NoError(t, err)
49+
require.Equal(t, "fo1_foo", cfg.Tokens.All())
50+
51+
c1, err := cfg.Watch(ctx)
52+
require.NoError(t, err)
53+
54+
c2, err := cfg.Watch(ctx)
55+
require.NoError(t, err)
56+
57+
cfgs, errs := getConfigChanges(c1, c2)
58+
require.Equal(t, 2, len(errs))
59+
require.Equal(t, 0, len(cfgs))
60+
61+
require.NoError(t, os.WriteFile(path, []byte(`access_token: fo1_bar`), 0644))
62+
63+
cfgs, errs = getConfigChanges(c1, c2)
64+
require.Equal(t, 0, len(errs))
65+
require.Equal(t, 2, len(cfgs))
66+
require.Equal(t, cfgs[0], cfgs[1])
67+
require.Equal(t, "fo1_bar", cfgs[0].Tokens.All())
68+
69+
cfg.Unwatch(c1)
70+
71+
require.NoError(t, os.WriteFile(path, []byte(`access_token: fo1_baz`), 0644))
72+
73+
cfgs, errs = getConfigChanges(c2)
74+
require.Equal(t, 0, len(errs))
75+
require.Equal(t, 1, len(cfgs))
76+
require.Equal(t, "fo1_baz", cfgs[0].Tokens.All())
77+
78+
shutdown := make(chan struct{})
79+
go func() {
80+
defer close(shutdown)
81+
tm.Shutdown()
82+
}()
83+
select {
84+
case <-shutdown:
85+
case <-time.After(50 * time.Millisecond):
86+
t.Fatal("slow shutdown")
87+
}
88+
89+
_, open := <-c1
90+
require.False(t, open)
91+
_, open = <-c2
92+
require.False(t, open)
93+
94+
_, err = cfg.Watch(ctx)
95+
assert.Error(t, err)
96+
require.EqualError(t, err, context.Canceled.Error())
97+
}
98+
99+
func getConfigChanges(chans ...chan *Config) ([]*Config, []error) {
100+
var (
101+
cfgs []*Config
102+
errs []error
103+
m sync.Mutex
104+
wg sync.WaitGroup
105+
)
106+
107+
for _, ch := range chans {
108+
ch := ch
109+
110+
wg.Add(1)
111+
go func() {
112+
defer wg.Done()
113+
defer m.Unlock()
114+
115+
select {
116+
case cfg, open := <-ch:
117+
m.Lock()
118+
if open {
119+
cfgs = append(cfgs, cfg)
120+
} else {
121+
errs = append(errs, errors.New("closed"))
122+
}
123+
case <-time.After(50 * time.Millisecond):
124+
m.Lock()
125+
errs = append(errs, errors.New("timeout"))
126+
}
127+
}()
128+
}
129+
130+
wg.Wait()
131+
132+
return cfgs, errs
133+
}

0 commit comments

Comments
 (0)