Skip to content

Commit db422d9

Browse files
git-hulkclaude
andcommitted
Add drift test keeping ASTVisitor and Walk traversals in sync
Accept/ASTVisitor and Walk/WalkFunc are two independent traversal engines that encode every node's children separately, so a new AST node type added to one can silently be missing from the other. Add a static test that parses the package source and asserts every type with an Accept method has both a Visit method on the ASTVisitor interface and a case in Walk's type switch, and vice versa. The test immediately caught four types missing from Walk - BoolLiteral, CreateNamedCollection, NamedCollectionParam and WindowFrameParam - and their cases are added. Also fix a confirmed traversal bug: JoinTableExpr.Accept returned early after visiting the sample ratio, so VisitJoinTableExpr was never called for table expressions carrying a SAMPLE clause. Add a regression test. Co-Authored-By: Claude Fable 5 <noreply@anthropic.com>
1 parent 3c3822c commit db422d9

3 files changed

Lines changed: 165 additions & 1 deletion

File tree

parser/ast.go

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,9 @@ func (j *JoinTableExpr) Accept(visitor ASTVisitor) error {
176176
return err
177177
}
178178
if j.SampleRatio != nil {
179-
return j.SampleRatio.Accept(visitor)
179+
if err := j.SampleRatio.Accept(visitor); err != nil {
180+
return err
181+
}
180182
}
181183
return visitor.VisitJoinTableExpr(j)
182184
}

parser/traversal_drift_test.go

Lines changed: 137 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
package parser
2+
3+
import (
4+
"go/ast"
5+
goparser "go/parser"
6+
"go/token"
7+
"os"
8+
"sort"
9+
"strings"
10+
"testing"
11+
12+
"github.com/stretchr/testify/require"
13+
)
14+
15+
// The package has two independent traversal engines: Accept/ASTVisitor and
16+
// Walk/WalkFunc. Each encodes every node's children separately, so a new AST
17+
// node type added to one can silently be forgotten in the other. This test
18+
// statically asserts that every type with an Accept method also has a Visit
19+
// method on the ASTVisitor interface and a case in Walk's type switch (and
20+
// vice versa), so the engines cannot drift.
21+
func TestTraversalEnginesCoverSameNodeTypes(t *testing.T) {
22+
entries, err := os.ReadDir(".")
23+
require.NoError(t, err)
24+
25+
fset := token.NewFileSet()
26+
acceptTypes := map[string]bool{}
27+
visitorTypes := map[string]bool{}
28+
walkTypes := map[string]bool{}
29+
30+
for _, entry := range entries {
31+
name := entry.Name()
32+
if !strings.HasSuffix(name, ".go") || strings.HasSuffix(name, "_test.go") {
33+
continue
34+
}
35+
file, err := goparser.ParseFile(fset, name, nil, 0)
36+
require.NoError(t, err)
37+
for _, decl := range file.Decls {
38+
switch d := decl.(type) {
39+
case *ast.FuncDecl:
40+
switch {
41+
case d.Name.Name == "Accept" && d.Recv != nil && len(d.Recv.List) == 1:
42+
if star, ok := d.Recv.List[0].Type.(*ast.StarExpr); ok {
43+
if ident, ok := star.X.(*ast.Ident); ok {
44+
acceptTypes[ident.Name] = true
45+
}
46+
}
47+
case d.Name.Name == "Walk" && d.Recv == nil:
48+
ast.Inspect(d.Body, func(n ast.Node) bool {
49+
cc, ok := n.(*ast.CaseClause)
50+
if !ok {
51+
return true
52+
}
53+
for _, expr := range cc.List {
54+
if star, ok := expr.(*ast.StarExpr); ok {
55+
if ident, ok := star.X.(*ast.Ident); ok {
56+
walkTypes[ident.Name] = true
57+
}
58+
}
59+
}
60+
return true
61+
})
62+
}
63+
case *ast.GenDecl:
64+
for _, spec := range d.Specs {
65+
ts, ok := spec.(*ast.TypeSpec)
66+
if !ok || ts.Name.Name != "ASTVisitor" {
67+
continue
68+
}
69+
iface, ok := ts.Type.(*ast.InterfaceType)
70+
if !ok {
71+
continue
72+
}
73+
for _, method := range iface.Methods.List {
74+
ft, ok := method.Type.(*ast.FuncType)
75+
if !ok || ft.Params == nil || len(ft.Params.List) != 1 {
76+
continue
77+
}
78+
if star, ok := ft.Params.List[0].Type.(*ast.StarExpr); ok {
79+
if ident, ok := star.X.(*ast.Ident); ok {
80+
visitorTypes[ident.Name] = true
81+
}
82+
}
83+
}
84+
}
85+
}
86+
}
87+
}
88+
89+
require.NotEmpty(t, acceptTypes)
90+
require.NotEmpty(t, visitorTypes)
91+
require.NotEmpty(t, walkTypes)
92+
93+
require.Empty(t, diffSet(acceptTypes, visitorTypes),
94+
"types with an Accept method but no ASTVisitor Visit method")
95+
require.Empty(t, diffSet(visitorTypes, acceptTypes),
96+
"types with an ASTVisitor Visit method but no Accept method")
97+
require.Empty(t, diffSet(acceptTypes, walkTypes),
98+
"types with an Accept method but no case in Walk's type switch")
99+
require.Empty(t, diffSet(walkTypes, acceptTypes),
100+
"types with a case in Walk's type switch but no Accept method")
101+
}
102+
103+
// diffSet returns the members of a that are not in b, sorted.
104+
func diffSet(a, b map[string]bool) []string {
105+
var diff []string
106+
for name := range a {
107+
if !b[name] {
108+
diff = append(diff, name)
109+
}
110+
}
111+
sort.Strings(diff)
112+
return diff
113+
}
114+
115+
// VisitJoinTableExpr must be called even when the table expression carries a
116+
// SAMPLE clause; Accept used to return early after visiting the sample ratio.
117+
func TestVisitJoinTableExprWithSampleRatio(t *testing.T) {
118+
stmts, err := NewParser("SELECT * FROM t SAMPLE 1/10").ParseStmts()
119+
require.NoError(t, err)
120+
require.Len(t, stmts, 1)
121+
122+
var visitedJoinTable, visitedSample bool
123+
visitor := &DefaultASTVisitor{
124+
Visit: func(expr Expr) error {
125+
switch expr.(type) {
126+
case *JoinTableExpr:
127+
visitedJoinTable = true
128+
case *SampleClause:
129+
visitedSample = true
130+
}
131+
return nil
132+
},
133+
}
134+
require.NoError(t, stmts[0].Accept(visitor))
135+
require.True(t, visitedSample, "SampleClause was not visited")
136+
require.True(t, visitedJoinTable, "JoinTableExpr was not visited")
137+
}

parser/walk.go

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,8 @@ func Walk(node Expr, fn WalkFunc) bool {
122122
// Leaf node
123123
case *StringLiteral:
124124
// Leaf node
125+
case *BoolLiteral:
126+
// Leaf node
125127
case *NullLiteral:
126128
// Leaf node
127129
case *NotNullLiteral:
@@ -431,6 +433,10 @@ func Walk(node Expr, fn WalkFunc) bool {
431433
if !Walk(n.Number, fn) {
432434
return false
433435
}
436+
case *WindowFrameParam:
437+
if !Walk(n.Param, fn) {
438+
return false
439+
}
434440
case *TopClause:
435441
if !Walk(n.Number, fn) {
436442
return false
@@ -740,6 +746,25 @@ func Walk(node Expr, fn WalkFunc) bool {
740746
if !Walk(n.Expr, fn) {
741747
return false
742748
}
749+
case *CreateNamedCollection:
750+
if !Walk(n.Name, fn) {
751+
return false
752+
}
753+
if !Walk(n.OnCluster, fn) {
754+
return false
755+
}
756+
for _, param := range n.Params {
757+
if !Walk(param, fn) {
758+
return false
759+
}
760+
}
761+
case *NamedCollectionParam:
762+
if !Walk(n.Name, fn) {
763+
return false
764+
}
765+
if !Walk(n.Value, fn) {
766+
return false
767+
}
743768
case *CreateRole:
744769
for _, name := range n.RoleNames {
745770
if !Walk(name, fn) {

0 commit comments

Comments
 (0)