Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions .goreleaser.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -48,9 +48,10 @@ dockers:
goos: linux
goarch: amd64
dockerfile: Dockerfile
skip_push: auto
image_templates:
- '{{ if eq .Prerelease "" }}foomo/gotsrpc:latest-amd64{{ end }}'
- 'foomo/gotsrpc:{{ .Version }}-amd64'
- '{{ if eq .Prerelease "" }}foomo/gotsrpc:{{ .Version }}-amd64{{ end }}'
- '{{ if eq .Prerelease "" }}foomo/gotsrpc:{{ .Major }}-amd64{{ end }}'
- '{{ if eq .Prerelease "" }}foomo/gotsrpc:{{ .Major }}.{{ .Minor }}-amd64{{ end }}'
build_flag_templates:
Expand All @@ -70,9 +71,10 @@ dockers:
goos: linux
goarch: arm64
dockerfile: Dockerfile
skip_push: auto
image_templates:
- '{{ if eq .Prerelease "" }}foomo/gotsrpc:latest-arm64{{ end }}'
- 'foomo/gotsrpc:{{ .Version }}-arm64'
- '{{ if eq .Prerelease "" }}foomo/gotsrpc:{{ .Version }}-arm64{{ end }}'
- '{{ if eq .Prerelease "" }}foomo/gotsrpc:{{ .Major }}-arm64{{ end }}'
- '{{ if eq .Prerelease "" }}foomo/gotsrpc:{{ .Major }}.{{ .Minor }}-arm64{{ end }}'
build_flag_templates:
Expand Down
264 changes: 248 additions & 16 deletions internal/parser/servicereader.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,22 @@ import (
"github.com/foomo/gotsrpc/v2/internal/model"
)

// interfaceInfo holds a parsed interface type and its file imports.
type interfaceInfo struct {
iface *ast.InterfaceType
imports fileImportSpecMap
typeParams []string
}

// resolvedMethod pairs an AST function type with the file imports it was declared in.
type resolvedMethod struct {
name string
funcTyp *ast.FuncType
imports fileImportSpecMap
typeSubst map[string]ast.Expr
substImports fileImportSpecMap // imports for resolving typeSubst expressions
}

func Read(
goPaths []string,
gomod config.Namespace,
Expand Down Expand Up @@ -149,7 +165,7 @@ func Read(
return
}

func readServiceFile(file *ast.File, packageName string, services model.ServiceList) error {
func readServiceFile(file *ast.File, packageName string, services model.ServiceList, pkgInterfaces map[string]interfaceInfo) error {
findService := func(serviceName string) (service *model.Service, ok bool) {
for _, service := range services {
if service.Name == serviceName {
Expand Down Expand Up @@ -205,20 +221,28 @@ func readServiceFile(file *ast.File, packageName string, services model.ServiceL
if iSpec, ok := typeSpec.Type.(*ast.InterfaceType); ok {
service.IsInterface = true

for _, fieldDecl := range iSpec.Methods.List {
if funcDecl, ok := fieldDecl.Type.(*ast.FuncType); ok {
if len(fieldDecl.Names) == 0 {
continue
}
resolved := resolveInterfaceMethods(iSpec, fileImports, pkgInterfaces, map[string]bool{ident.Name: true}, nil, nil)
for _, m := range resolved {
trace(" on sth:", m.name)

var tpNames []string
for k := range m.typeSubst {
tpNames = append(tpNames, k)
}

mname := fieldDecl.Names[0]
trace(" on sth:", mname.Name)
service.Methods = append(service.Methods, &model.Method{
Name: mname.Name,
Args: readFields(funcDecl.Params, fileImports),
Return: readFields(funcDecl.Results, fileImports),
})
args := readFields(m.funcTyp.Params, m.imports, tpNames...)

ret := readFields(m.funcTyp.Results, m.imports, tpNames...)
if len(m.typeSubst) > 0 {
substituteTypeParams(args, m.typeSubst, m.substImports)
substituteTypeParams(ret, m.typeSubst, m.substImports)
}

service.Methods = append(service.Methods, &model.Method{
Name: m.name,
Args: args,
Return: ret,
})
}
}
}
Expand All @@ -234,7 +258,7 @@ func readServiceFile(file *ast.File, packageName string, services model.ServiceL
return nil
}

func readFields(fieldList *ast.FieldList, fileImports fileImportSpecMap) (fields []*model.Field) {
func readFields(fieldList *ast.FieldList, fileImports fileImportSpecMap, typeParams ...string) (fields []*model.Field) {
trace("reading fields")

fields = []*model.Field{}
Expand All @@ -244,7 +268,7 @@ func readFields(fieldList *ast.FieldList, fileImports fileImportSpecMap) (fields
}

for _, param := range fieldList.List {
names, value, _ := readField(param, fileImports, nil)
names, value, _ := readField(param, fileImports, typeParams)
for _, name := range names {
fields = append(fields, &model.Field{
Name: name,
Expand All @@ -258,6 +282,47 @@ func readFields(fieldList *ast.FieldList, fileImports fileImportSpecMap) (fields
return
}

// substituteTypeParams replaces TypeParam entries in fields with concrete types from the substitution map.
// substImports are the imports needed to resolve the substitution expressions (may differ from the method's file imports).
func substituteTypeParams(fields []*model.Field, subst map[string]ast.Expr, substImports fileImportSpecMap) {
for _, f := range fields {
substituteValue(f.Value, subst, substImports)
}
}

// substituteValue recursively replaces TypeParam references with concrete types.
func substituteValue(v *model.Value, subst map[string]ast.Expr, substImports fileImportSpecMap) {
if v == nil {
return
}

if v.TypeParam != "" {
if expr, ok := subst[v.TypeParam]; ok {
wasPtr := v.IsPtr
*v = model.Value{}
v.IsPtr = wasPtr
loadValueExpr(v, expr, substImports, nil)
}

return
}

if v.Array != nil {
substituteValue(v.Array.Value, subst, substImports)
}

if v.Map != nil {
substituteValue(v.Map.Key, subst, substImports)
substituteValue(v.Map.Value, subst, substImports)
}

if v.StructType != nil {
for _, arg := range v.StructType.TypeArgs {
substituteValue(arg, subst, substImports)
}
}
}

func readServicesInPackage(pkg *ast.Package, packageName string, serviceMap map[string]string) (services model.ServiceList, err error) {
if pkg == nil {
return nil, errors.New("package cannot be nil")
Expand All @@ -272,6 +337,8 @@ func readServicesInPackage(pkg *ast.Package, packageName string, serviceMap map[
})
}

pkgInterfaces := collectPackageInterfaces(pkg, packageName)

pkgFiles := make([]string, 0, len(pkg.Files))
for k := range pkg.Files {
pkgFiles = append(pkgFiles, k)
Expand All @@ -282,7 +349,7 @@ func readServicesInPackage(pkg *ast.Package, packageName string, serviceMap map[
for _, k := range pkgFiles {
file := pkg.Files[k]

err = readServiceFile(file, packageName, services)
err = readServiceFile(file, packageName, services, pkgInterfaces)
if err != nil {
return
}
Expand Down Expand Up @@ -690,6 +757,10 @@ func getScalarForField(value *model.Value) []*model.Scalar {
switch {
case value.Scalar != nil:
scalarTypes = append(scalarTypes, value.Scalar)
case value.StructType != nil:
for _, arg := range value.StructType.TypeArgs {
scalarTypes = append(scalarTypes, getScalarForField(arg)...)
}
case value.Map != nil:
if value.Map.Key != nil {
if v := getScalarForField(value.Map.Key); v != nil {
Expand Down Expand Up @@ -744,3 +815,164 @@ func collectStructTypes(fields []*model.Field, structTypes map[string]bool) {
}
}
}

// collectPackageInterfaces scans all files in the package and builds a map
// of interface names to their AST and file imports.
func collectPackageInterfaces(pkg *ast.Package, packageName string) map[string]interfaceInfo {
result := map[string]interfaceInfo{}

for _, file := range pkg.Files {
fileImports := getFileImports(file, packageName)
for _, decl := range file.Decls {
genDecl, ok := decl.(*ast.GenDecl)
if !ok || genDecl.Tok != token.TYPE {
continue
}

for _, spec := range genDecl.Specs {
typeSpec, ok := spec.(*ast.TypeSpec)
if !ok {
continue
}

iface, ok := typeSpec.Type.(*ast.InterfaceType)
if !ok {
continue
}

var typeParams []string

if typeSpec.TypeParams != nil {
for _, tp := range typeSpec.TypeParams.List {
for _, n := range tp.Names {
typeParams = append(typeParams, n.Name)
}
}
}

result[typeSpec.Name.Name] = interfaceInfo{
iface: iface,
imports: fileImports,
typeParams: typeParams,
}
}
}
}

return result
}

// resolveExpr resolves an AST expression through a type substitution map.
func resolveExpr(expr ast.Expr, typeSubst map[string]ast.Expr) ast.Expr {
if ident, ok := expr.(*ast.Ident); ok {
if sub, ok := typeSubst[ident.Name]; ok {
return sub
}
}

return expr
}

// resolveInterfaceMethods recursively collects all methods from an interface,
// following embedded interfaces via the pkgInterfaces map. Uses visited for cycle protection.
// typeSubst maps type parameter names to concrete type expressions.
// substImports are the imports needed to resolve expressions in typeSubst.
func resolveInterfaceMethods(iface *ast.InterfaceType, imports fileImportSpecMap, pkgInterfaces map[string]interfaceInfo, visited map[string]bool, typeSubst map[string]ast.Expr, substImports fileImportSpecMap) []resolvedMethod {
var methods []resolvedMethod

for _, field := range iface.Methods.List {
switch ft := field.Type.(type) {
case *ast.FuncType:
if len(field.Names) == 0 {
continue
}

methods = append(methods, resolvedMethod{
name: field.Names[0].Name,
funcTyp: ft,
imports: imports,
typeSubst: typeSubst,
substImports: substImports,
})
case *ast.Ident:
// Embedded interface reference (non-generic)
if visited[ft.Name] {
continue
}

visited[ft.Name] = true
if info, ok := pkgInterfaces[ft.Name]; ok {
methods = append(methods, resolveInterfaceMethods(info.iface, info.imports, pkgInterfaces, visited, nil, nil)...)
}
case *ast.IndexExpr:
// Generic embedded interface with single type arg: Base[string] or Base[T]
ident, ok := ft.X.(*ast.Ident)
if !ok {
continue
}

if visited[ident.Name] {
continue
}

visited[ident.Name] = true

info, ok := pkgInterfaces[ident.Name]
if !ok {
continue
}
// Build substitution map for the embedded interface's type params.
// Determine the imports needed to resolve the substitution expressions:
// if the arg was resolved from the parent's typeSubst, use substImports;
// otherwise use the current imports (where the embedding is written).
newSubst := map[string]ast.Expr{}
newSubstImports := imports

resolvedArg := resolveExpr(ft.Index, typeSubst)
if resolvedArg != ft.Index && substImports != nil {
newSubstImports = substImports
}

if len(info.typeParams) > 0 {
newSubst[info.typeParams[0]] = resolvedArg
}

methods = append(methods, resolveInterfaceMethods(info.iface, info.imports, pkgInterfaces, visited, newSubst, newSubstImports)...)
case *ast.IndexListExpr:
// Generic embedded interface with multiple type args: Keyed[string, int]
ident, ok := ft.X.(*ast.Ident)
if !ok {
continue
}

if visited[ident.Name] {
continue
}

visited[ident.Name] = true

info, ok := pkgInterfaces[ident.Name]
if !ok {
continue
}

newSubst := map[string]ast.Expr{}
newSubstImports := imports

for i, idx := range ft.Indices {
resolvedArg := resolveExpr(idx, typeSubst)
if resolvedArg != idx && substImports != nil {
newSubstImports = substImports
}

if i < len(info.typeParams) {
newSubst[info.typeParams[i]] = resolvedArg
}
}

methods = append(methods, resolveInterfaceMethods(info.iface, info.imports, pkgInterfaces, visited, newSubst, newSubstImports)...)
}
}

return methods
}
18 changes: 18 additions & 0 deletions internal/parser/typereader.go
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,15 @@ func loadValueExpr(v *model.Value, expr ast.Expr, fileImports fileImportSpecMap,
case *ast.IndexExpr:
// Generic type with single type argument: T[X]
loadValueExpr(v, exprType.X, fileImports, typeParams)
// Cross-package types have ident.Obj==nil, so readAstType may
// create a Scalar instead of StructType. Promote for generics.
if v.StructType == nil && v.Scalar != nil {
v.StructType = &model.StructType{
Name: v.Scalar.Name,
Package: v.Scalar.Package,
}
v.Scalar = nil
}

if v.StructType != nil {
arg := &model.Value{}
Expand All @@ -310,6 +319,15 @@ func loadValueExpr(v *model.Value, expr ast.Expr, fileImports fileImportSpecMap,
case *ast.IndexListExpr:
// Generic type with multiple type arguments: T[X, Y]
loadValueExpr(v, exprType.X, fileImports, typeParams)
// Cross-package types have ident.Obj==nil, so readAstType may
// create a Scalar instead of StructType. Promote for generics.
if v.StructType == nil && v.Scalar != nil {
v.StructType = &model.StructType{
Name: v.Scalar.Name,
Package: v.Scalar.Package,
}
v.Scalar = nil
}

if v.StructType != nil {
for _, index := range exprType.Indices {
Expand Down
2 changes: 1 addition & 1 deletion tests/aliases/client/client.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ test("registryValue", async () => {
});

test("indexValue", async () => {
const v = { [Category.A]: [{ id: "1", status: Status.Active, priority: Priority.Low, rating: 1.0 }], [Category.B]: null };
const v = { [Category.A]: [{ id: "1", status: Status.Active, priority: Priority.Low, rating: 1.0 }] };
const ret = await client.indexValue(v);
expect(ret![Category.A]).toHaveLength(1);
expect(ret![Category.A]![0]!.id).toBe("1");
Expand Down
Loading
Loading