Skip to content

Commit 986ba72

Browse files
Klesh WongCopilot
andcommitted
fix: harden dbt pipeline inputs
Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent 420e494 commit 986ba72

13 files changed

Lines changed: 361 additions & 35 deletions

File tree

backend/helpers/oidchelper/config.go

Lines changed: 16 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -88,20 +88,28 @@ func (c *Config) ProviderNames() []string {
8888
}
8989

9090
// LoadConfig reads auth env vars via Viper and validates required fields.
91-
// Returns Config{AuthEnabled:false} when AUTH_ENABLED=false (the default,
92-
// preserves historical behavior).
91+
// AUTH_ENABLED defaults to true unless it is explicitly set to false.
9392
func LoadConfig(basicRes context.BasicRes) (*Config, error) {
9493
cfg := basicRes.GetConfigReader()
9594

96-
if !cfg.GetBool("AUTH_ENABLED") {
95+
authEnabled := true
96+
if cfg.IsSet("AUTH_ENABLED") {
97+
authEnabled = cfg.GetBool("AUTH_ENABLED")
98+
}
99+
if !authEnabled {
97100
return &Config{AuthEnabled: false}, nil
98101
}
99102

103+
oidcEnabled := cfg.GetBool("OIDC_ENABLED")
100104
sessionSecret := strings.TrimSpace(cfg.GetString("SESSION_SECRET"))
101-
if sessionSecret == "" {
102-
return nil, fmt.Errorf("AUTH_ENABLED=true but SESSION_SECRET is not set")
103-
}
104-
if len(sessionSecret) < 32 {
105+
if oidcEnabled {
106+
if sessionSecret == "" {
107+
return nil, fmt.Errorf("OIDC_ENABLED=true but SESSION_SECRET is not set")
108+
}
109+
if len(sessionSecret) < 32 {
110+
return nil, fmt.Errorf("SESSION_SECRET must be at least 32 bytes")
111+
}
112+
} else if sessionSecret != "" && len(sessionSecret) < 32 {
105113
return nil, fmt.Errorf("SESSION_SECRET must be at least 32 bytes")
106114
}
107115

@@ -121,7 +129,7 @@ func LoadConfig(basicRes context.BasicRes) (*Config, error) {
121129

122130
out := &Config{
123131
AuthEnabled: true,
124-
OIDCEnabled: cfg.GetBool("OIDC_ENABLED"),
132+
OIDCEnabled: oidcEnabled,
125133
Providers: map[string]*ProviderConfig{},
126134
LogoutRedirect: cfg.GetBool("OIDC_LOGOUT_REDIRECT"),
127135
SessionSecret: []byte(sessionSecret),

backend/helpers/oidchelper/config_test.go

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,13 @@ package oidchelper
2020
import (
2121
"reflect"
2222
"testing"
23+
24+
"github.com/spf13/viper"
25+
26+
"github.com/apache/incubator-devlake/core/config"
27+
corectx "github.com/apache/incubator-devlake/core/context"
28+
"github.com/apache/incubator-devlake/core/dal"
29+
"github.com/apache/incubator-devlake/core/log"
2330
)
2431

2532
func TestParseScopes(t *testing.T) {
@@ -84,3 +91,43 @@ func TestProviderNamesSorted(t *testing.T) {
8491
t.Errorf("ProviderNames = %v, want sorted [entra google]", names)
8592
}
8693
}
94+
95+
type basicResStub struct {
96+
cfg config.ConfigReader
97+
}
98+
99+
func (b basicResStub) GetConfigReader() config.ConfigReader { return b.cfg }
100+
func (b basicResStub) GetConfig(string) string { return "" }
101+
func (b basicResStub) GetLogger() log.Logger { return nil }
102+
func (b basicResStub) NestedLogger(string) corectx.BasicRes { return nil }
103+
func (b basicResStub) ReplaceLogger(log.Logger) corectx.BasicRes {
104+
return nil
105+
}
106+
func (b basicResStub) GetDal() dal.Dal { return nil }
107+
108+
func TestLoadConfigDefaultsAuthEnabled(t *testing.T) {
109+
v := viper.New()
110+
111+
cfg, err := LoadConfig(basicResStub{cfg: v})
112+
if err != nil {
113+
t.Fatalf("LoadConfig returned error: %v", err)
114+
}
115+
if !cfg.AuthEnabled {
116+
t.Fatal("AuthEnabled should default to true when AUTH_ENABLED is unset")
117+
}
118+
if cfg.OIDCEnabled {
119+
t.Fatal("OIDCEnabled should default to false")
120+
}
121+
if len(cfg.SessionSecret) != 0 {
122+
t.Fatalf("SessionSecret = %q, want empty when OIDC is disabled", string(cfg.SessionSecret))
123+
}
124+
}
125+
126+
func TestLoadConfigRequiresSessionSecretForOIDC(t *testing.T) {
127+
v := viper.New()
128+
v.Set("OIDC_ENABLED", true)
129+
130+
if _, err := LoadConfig(basicResStub{cfg: v}); err == nil {
131+
t.Fatal("LoadConfig should reject OIDC-enabled config without SESSION_SECRET")
132+
}
133+
}

backend/plugins/dbt/dbt.go

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@ var PluginEntry impl.Dbt
2828
// standalone mode for debugging
2929
func main() {
3030
dbtCmd := &cobra.Command{Use: "dbt"}
31-
_ = dbtCmd.MarkFlagRequired("projectPath")
3231
projectPath := dbtCmd.Flags().StringP("projectPath", "p", "/Users/abeizn/demoapp", "user dbt project directory.")
3332
projectGitURL := dbtCmd.Flags().StringP("projectGitURL", "g", "", "user dbt project git url.")
3433
projectName := dbtCmd.Flags().StringP("projectName", "n", "demoapp", "user dbt project name.")

backend/plugins/dbt/impl/impl.go

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,8 @@ limitations under the License.
1818
package impl
1919

2020
import (
21+
"fmt"
22+
2123
"github.com/apache/incubator-devlake/core/dal"
2224
"github.com/apache/incubator-devlake/core/errors"
2325
"github.com/apache/incubator-devlake/core/plugin"
@@ -28,6 +30,7 @@ import (
2830
var _ interface {
2931
plugin.PluginMeta
3032
plugin.PluginTask
33+
plugin.CloseablePluginTask
3134
plugin.PluginModel
3235
} = (*Dbt)(nil)
3336

@@ -54,8 +57,8 @@ func (p Dbt) PrepareTaskData(taskCtx plugin.TaskContext, options map[string]inte
5457
if err != nil {
5558
return nil, err
5659
}
57-
if op.ProjectPath == "" {
58-
return nil, errors.Default.New("projectPath is required for dbt plugin")
60+
if err := tasks.PrepareOptions(&op, taskCtx.GetConfig(tasks.DbtProjectBaseDirConfigKey)); err != nil {
61+
return nil, err
5962
}
6063

6164
if op.ProjectTarget == "" {
@@ -71,6 +74,17 @@ func (p Dbt) RootPkgPath() string {
7174
return "github.com/apache/incubator-devlake/plugins/dbt"
7275
}
7376

77+
func (p Dbt) Close(taskCtx plugin.TaskContext) errors.Error {
78+
data, ok := taskCtx.GetData().(*tasks.DbtTaskData)
79+
if !ok || data == nil || data.Options == nil || !data.Options.ManagedProjectDir {
80+
return nil
81+
}
82+
if err := tasks.CleanupManagedProjectDir(data.Options); err != nil {
83+
return errors.Default.Wrap(err, fmt.Sprintf("cleanup dbt project path %q", data.Options.ProjectPath))
84+
}
85+
return nil
86+
}
87+
7488
func (p Dbt) Name() string {
7589
return "dbt"
7690
}

backend/plugins/dbt/tasks/convertor.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -234,4 +234,5 @@ var DbtConverterMeta = plugin.SubTaskMeta{
234234
EntryPoint: DbtConverter,
235235
EnabledByDefault: true,
236236
Description: "Convert data by dbt",
237+
Dependencies: []*plugin.SubTaskMeta{&GitMeta},
237238
}

backend/plugins/dbt/tasks/git.go

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -32,20 +32,26 @@ func Git(taskCtx plugin.SubTaskContext) errors.Error {
3232
return nil
3333
}
3434

35-
// clean ProjectPath
36-
err := os.RemoveAll(data.Options.ProjectPath)
35+
projectBaseDir, err := ensureProjectBaseDir(data.Options.ProjectBaseDir)
3736
if err != nil {
38-
logger.Error(err, "cleanup before clone dbt project failed")
39-
return errors.Convert(err)
37+
logger.Error(err, "prepare dbt workspace failed")
38+
return err
4039
}
40+
projectPath, mkErr := os.MkdirTemp(projectBaseDir, "project-*")
41+
if mkErr != nil {
42+
logger.Error(mkErr, "create managed dbt project directory failed")
43+
return errors.Convert(mkErr)
44+
}
45+
data.Options.ProjectPath = projectPath
4146

42-
// git clone from ProjectGitURL into ProjectPath
47+
// git clone from ProjectGitURL into a managed temporary project directory
4348
cmd := exec.Command("git", "clone", data.Options.ProjectGitURL, data.Options.ProjectPath)
4449
logger.Info("start clone dbt project: %v", cmd)
45-
out, err := cmd.CombinedOutput()
46-
if err != nil {
47-
logger.Error(err, "clone dbt project failed")
48-
return errors.Convert(err)
50+
out, cloneErr := cmd.CombinedOutput()
51+
if cloneErr != nil {
52+
_ = os.RemoveAll(data.Options.ProjectPath)
53+
logger.Error(cloneErr, "clone dbt project failed")
54+
return errors.Convert(cloneErr)
4955
}
5056
logger.Info("clone dbt project success: %v", string(out))
5157
return nil
Lines changed: 132 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,132 @@
1+
/*
2+
Licensed to the Apache Software Foundation (ASF) under one or more
3+
contributor license agreements. See the NOTICE file distributed with
4+
this work for additional information regarding copyright ownership.
5+
The ASF licenses this file to You under the Apache License, Version 2.0
6+
(the "License"); you may not use this file except in compliance with
7+
the License. You may obtain a copy of the License at
8+
9+
http://www.apache.org/licenses/LICENSE-2.0
10+
11+
Unless required by applicable law or agreed to in writing, software
12+
distributed under the License is distributed on an "AS IS" BASIS,
13+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
See the License for the specific language governing permissions and
15+
limitations under the License.
16+
*/
17+
18+
package tasks
19+
20+
import (
21+
"net/url"
22+
"os"
23+
"path/filepath"
24+
"strings"
25+
26+
"github.com/apache/incubator-devlake/core/errors"
27+
)
28+
29+
const (
30+
DbtProjectBaseDirConfigKey = "DBT_PROJECTS_DIR"
31+
dbtProjectBaseDirName = "devlake-dbt-projects"
32+
)
33+
34+
func PrepareOptions(op *DbtOptions, configuredBaseDir string) errors.Error {
35+
if op == nil {
36+
return errors.Default.New("dbt options are required")
37+
}
38+
39+
baseDir, err := normalizeBaseDir(configuredBaseDir)
40+
if err != nil {
41+
return err
42+
}
43+
op.ProjectBaseDir = baseDir
44+
op.ProjectPath = strings.TrimSpace(op.ProjectPath)
45+
op.ProjectGitURL = strings.TrimSpace(op.ProjectGitURL)
46+
47+
if op.ProjectGitURL != "" {
48+
if err := validateProjectGitURL(op.ProjectGitURL); err != nil {
49+
return err
50+
}
51+
op.ProjectPath = ""
52+
op.ManagedProjectDir = true
53+
return nil
54+
}
55+
56+
if op.ProjectPath == "" {
57+
return errors.Default.New("projectPath is required for local dbt projects")
58+
}
59+
60+
projectPath, err := normalizePathWithinBase(baseDir, op.ProjectPath)
61+
if err != nil {
62+
return err
63+
}
64+
op.ProjectPath = projectPath
65+
return nil
66+
}
67+
68+
func normalizeBaseDir(configuredBaseDir string) (string, errors.Error) {
69+
baseDir := strings.TrimSpace(configuredBaseDir)
70+
if baseDir == "" {
71+
baseDir = filepath.Join(os.TempDir(), dbtProjectBaseDirName)
72+
}
73+
baseDir, err := filepath.Abs(filepath.Clean(baseDir))
74+
if err != nil {
75+
return "", errors.Convert(err)
76+
}
77+
return baseDir, nil
78+
}
79+
80+
func normalizePathWithinBase(baseDir string, candidate string) (string, errors.Error) {
81+
normalizedPath, err := filepath.Abs(filepath.Clean(candidate))
82+
if err != nil {
83+
return "", errors.Convert(err)
84+
}
85+
rel, err := filepath.Rel(baseDir, normalizedPath)
86+
if err != nil {
87+
return "", errors.Convert(err)
88+
}
89+
if rel == ".." || strings.HasPrefix(rel, ".."+string(os.PathSeparator)) {
90+
return "", errors.Default.New("projectPath must stay within " + baseDir)
91+
}
92+
return normalizedPath, nil
93+
}
94+
95+
func validateProjectGitURL(rawURL string) errors.Error {
96+
u, err := url.Parse(rawURL)
97+
if err != nil {
98+
return errors.Convert(err)
99+
}
100+
if u.Scheme != "https" && u.Scheme != "ssh" {
101+
return errors.Default.New("projectGitURL must use https:// or ssh://")
102+
}
103+
if u.Host == "" {
104+
return errors.Default.New("projectGitURL must include a hostname")
105+
}
106+
if u.Path == "" || u.Path == "/" {
107+
return errors.Default.New("projectGitURL must include a repository path")
108+
}
109+
return nil
110+
}
111+
112+
func ensureProjectBaseDir(baseDir string) (string, errors.Error) {
113+
baseDir, err := normalizeBaseDir(baseDir)
114+
if err != nil {
115+
return "", err
116+
}
117+
if err := os.MkdirAll(baseDir, 0o755); err != nil {
118+
return "", errors.Convert(err)
119+
}
120+
return baseDir, nil
121+
}
122+
123+
func CleanupManagedProjectDir(op *DbtOptions) errors.Error {
124+
if op == nil || !op.ManagedProjectDir || op.ProjectPath == "" {
125+
return nil
126+
}
127+
projectPath, err := normalizePathWithinBase(op.ProjectBaseDir, op.ProjectPath)
128+
if err != nil {
129+
return err
130+
}
131+
return errors.Convert(os.RemoveAll(projectPath))
132+
}

0 commit comments

Comments
 (0)