Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
96 changes: 71 additions & 25 deletions instrumenter.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"go/token"
"go/types"
"os"
"path"
"path/filepath"
"reflect"
"runtime"
Expand Down Expand Up @@ -106,7 +107,7 @@ func (i *instrumenter) instrument(srcDir, singleFile, dstDir string) bool {
i.instrumentFile(name, file, dstDir)
})
}
i.writeGobcoFiles(dstDir, pkgs)
i.writeGobcoFiles(srcDir, dstDir, pkgs)
return true
}

Expand Down Expand Up @@ -629,7 +630,7 @@ var fixedTemplate string
//go:embed templates/gobco_no_testmain_test.go
var noTestMainTemplate string

func (i *instrumenter) writeGobcoFiles(tmpDir string, pkgs []*ast.Package) {
func (i *instrumenter) writeGobcoFiles(srcDir, tmpDir string, pkgs []*ast.Package) {
pkgname := pkgs[0].Name
fixPkgname := func(str string) string {
str = strings.TrimPrefix(str, "//go:build ignore\n// +build ignore\n\n")
Expand All @@ -642,7 +643,7 @@ func (i *instrumenter) writeGobcoFiles(tmpDir string, pkgs []*ast.Package) {
writeFile(filepath.Join(tmpDir, "gobco_no_testmain_test.go"), fixPkgname(noTestMainTemplate))
}

i.writeGobcoBlackBox(pkgs, tmpDir)
i.writeGobcoBlackBox(pkgs, srcDir, tmpDir)
}

func (i *instrumenter) writeGobcoGo(filename, pkgname string) {
Expand All @@ -667,35 +668,80 @@ func (i *instrumenter) writeGobcoGo(filename, pkgname string) {
writeFile(filename, sb.String())
}

// findPackagePath returns the canonical Go import path of the package
// located at srcDir, by combining the enclosing module path
// (read from go.mod) with the relative directory.
func findPackagePath(srcDir string) (string, error) {
moduleRoot, moduleRel, err := findInModule(srcDir)
if err != nil {
return "", err
}
if moduleRoot == "" {
return "", fmt.Errorf("no go.mod found for %q", srcDir)
}

moduleName, err := readModuleName(moduleRoot)
if err != nil {
return "", err
}

if moduleRel == "." {
return moduleName, nil
}
// Import paths always use '/' as separator,
// while moduleRel uses the OS separator.
return path.Join(moduleName, filepath.ToSlash(moduleRel)), nil
}

// readModuleName parses the 'module' directive from go.mod in moduleRoot.
//
// Reading go.mod directly avoids depending on the 'go' command at instrumentation time
// and on any external module like golang.org/x/mod.
func readModuleName(moduleRoot string) (string, error) {
data, err := os.ReadFile(filepath.Join(moduleRoot, "go.mod"))
if err != nil {
return "", err
}

for _, line := range strings.Split(string(data), "\n") {
trimmed := strings.TrimSpace(line)
if !strings.HasPrefix(trimmed, "module") {
continue
}
rest := strings.TrimSpace(trimmed[len("module"):])
if rest == "" {
continue
}
// The module path may optionally be quoted.
if rest[0] == '"' || rest[0] == '`' {
unquoted, err := strconv.Unquote(rest)
if err != nil {
return "", fmt.Errorf("parsing module directive: %w", err)
}
return unquoted, nil
}
// Strip any inline comment.
if idx := strings.Index(rest, "//"); idx >= 0 {
rest = strings.TrimSpace(rest[:idx])
}
return rest, nil
}

return "", fmt.Errorf("no module directive in %s", filepath.Join(moduleRoot, "go.mod"))
}

// writeGobcoBlackBox makes the function 'GobcoCover' available
// to black box tests (those in 'package x_test' instead of 'package x')
// by delegating to the function of the same name in the main package.
func (i *instrumenter) writeGobcoBlackBox(pkgs []*ast.Package, dstDir string) {
func (i *instrumenter) writeGobcoBlackBox(pkgs []*ast.Package, srcDir, dstDir string) {
if len(pkgs) < 2 {
return
}

// Copy the 'import' directive from one of the existing files.
pkgName, pkgPath := "", ""
for _, pkg := range pkgs {
forEachFile(pkg, func(name string, file *ast.File) {
for _, imp := range file.Imports {
var impName string
p, err := strconv.Unquote(imp.Path.Value)
ok(err)
if imp.Name != nil {
impName = imp.Name.Name
} else {
impName = filepath.Base(p)
}

if impName == pkgs[0].Name {
pkgName = impName
pkgPath = p
}
}
})
}
pkgPath, err := findPackagePath(srcDir)
ok(err)
// Import paths use '/', so path.Base is correct on all platforms.
pkgName := path.Base(pkgPath)

text := "" +
"package " + pkgs[0].Name + "_test\n" +
Expand Down
26 changes: 20 additions & 6 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -205,26 +205,40 @@ func (g *gobco) gopaths() string {
}

func (g *gobco) findInModule(dir string) (moduleRoot, moduleRel string) {
absDir, err := filepath.Abs(dir)
moduleRoot, moduleRel, err := findInModule(dir)
g.check(err)
return moduleRoot, moduleRel
}

// findInModule searches dir and its ancestors for a go.mod file.
// It returns the directory containing that file (moduleRoot)
// and the relative path from moduleRoot to dir (moduleRel).
// If no go.mod is found, both returned paths are empty.
func findInModule(dir string) (moduleRoot, moduleRel string, err error) {
absDir, err := filepath.Abs(dir)
if err != nil {
return "", "", err
}

abs := absDir
for {
if _, err := os.Lstat(filepath.Join(abs, "go.mod")); err == nil {
rel, err := filepath.Rel(abs, absDir)
g.check(err)
if _, statErr := os.Lstat(filepath.Join(abs, "go.mod")); statErr == nil {
rel, relErr := filepath.Rel(abs, absDir)
if relErr != nil {
return "", "", relErr
}

root := abs
if rel == "." {
root = dir
}

return root, rel
return root, rel, nil
}

parent := filepath.Dir(abs)
if parent == abs {
return "", ""
return "", "", nil
}
abs = parent
}
Expand Down
20 changes: 20 additions & 0 deletions main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -579,3 +579,23 @@ func Test_gobcoMain__issue38(t *testing.T) {
})
s.CheckEquals(stderr, "")
}

// Test_gobcoMain__issue33 covers the case where the target package's
// own directory contains an import of another package with the same
// base name (here, "iface/target"), and no '_test' file imports the
// target package directly. The previous implementation scanned imports
// in the directory and would pick the wrong import path, producing a
// gobco_bridge_test.go that failed to compile.
func Test_gobcoMain__issue33(t *testing.T) {
s := NewSuite(t)
defer s.TearDownTest()

stdout, stderr := s.RunMain(0, "gobco", "./testdata/issue33/target")

s.CheckEquals(s.GobcoLines(stdout), []string{
"Condition coverage: 1/2",
"testdata/issue33/target/target.go:6:5: " +
"condition \"a > 0\" was once true but never false",
})
s.CheckEquals(stderr, "")
}
10 changes: 10 additions & 0 deletions testdata/issue33/helper/helper.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
// Package helper provides an indirection used by the issue #33
// regression test, so that no file in the target package's directory
// directly imports the target package by its real import path.
package helper

import "github.com/rillig/gobco/testdata/issue33/target"

func AddViaTarget(a, b int) int {
return target.New().Add(a, b)
}
8 changes: 8 additions & 0 deletions testdata/issue33/iface/target/iface.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@
// Package target intentionally shares its name
// with github.com/rillig/gobco/testdata/issue33/target,
// to reproduce the import-resolution bug from issue #33.
package target

type Adder interface {
Add(a, b int) int
}
11 changes: 11 additions & 0 deletions testdata/issue33/target/factory.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package target

// The unaliased import here has filepath.Base("iface/target") == "target",
// which equals the current package name.
// The old import-scanning logic in writeGobcoBlackBox mistook this
// for the target package's own import path.
import "github.com/rillig/gobco/testdata/issue33/iface/target"

func New() target.Adder {
return &T{}
}
10 changes: 10 additions & 0 deletions testdata/issue33/target/target.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package target

type T struct{}

func (t *T) Add(a, b int) int {
if a > 0 {
return a + b
}
return b
}
13 changes: 13 additions & 0 deletions testdata/issue33/target/target_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package target_test

import (
"testing"

"github.com/rillig/gobco/testdata/issue33/helper"
)

func TestAdd(t *testing.T) {
if helper.AddViaTarget(1, 2) != 3 {
t.Fail()
}
}