Skip to content
Draft
7 changes: 5 additions & 2 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ tmp/
## Docker
!.dockerignore

## Node ##
### Node
!.npmrc
**/node_modules/

## TypeScript
Expand All @@ -39,5 +40,7 @@ tmp/
!.lefthook.yaml

## Vitepress
node_modules/
!docs/.vitepress
docs/.vitepress/*
!docs/.vitepress/theme/
!docs/.vitepress/config.mts
6 changes: 3 additions & 3 deletions .mise.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[tools]
# https://mise-tools.jdx.dev/tools/bun
bun = "1.3.10"
bun = "1.3.11"
# https://mise-tools.jdx.dev/tools/lefthook
lefthook = "2.1.1"
lefthook = "2.1.4"
# https://mise-tools.jdx.dev/tools/golangci-lint
golangci-lint = "2.10.1"
golangci-lint = "2.11.3"
2 changes: 1 addition & 1 deletion bufferedclient.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,7 @@ func (c *bufferedClient) Call(ctx context.Context, url string, endpoint string,
return NewClientError(errors.Wrap(errRequest, "failed to create request"))
}

resp, errDo := c.client.Do(request) //nolint:gosec
resp, errDo := c.client.Do(request)
if errDo != nil {
return NewClientError(errors.Wrap(errDo, "failed to send request"))
}
Expand Down
2 changes: 1 addition & 1 deletion instrumentation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func TestInstrumentedService(t *testing.T) {
})

rsp := httptest.NewRecorder()
req := httptest.NewRequest(http.MethodGet, "/test", nil)
req := httptest.NewRequestWithContext(t.Context(), http.MethodGet, "/test", nil)

handler(rsp, req)

Expand Down
19 changes: 19 additions & 0 deletions internal/codegen/gocode.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,12 +34,27 @@ func valueGoType(v *model.Value, aliases map[string]string, packageName string)
t += "[]" + valueGoType(v.Array.Value, aliases, packageName)
case len(v.GoScalarType) > 0:
t += v.GoScalarType
case v.TypeParam != "":
t += v.TypeParam
case v.StructType != nil:
if packageName != v.StructType.Package && aliases[v.StructType.Package] != "" {
t += aliases[v.StructType.Package] + "."
}

t += v.StructType.Name
if len(v.StructType.TypeArgs) > 0 {
t += "["

for i, arg := range v.StructType.TypeArgs {
if i > 0 {
t += ", "
}

t += valueGoType(arg, aliases, packageName)
}

t += "]"
}
case v.Map != nil:
t += `map[` + valueGoType(v.Map.Key, aliases, packageName) + `]` + valueGoType(v.Map.Value, aliases, packageName)
case v.Scalar != nil:
Expand Down Expand Up @@ -111,6 +126,10 @@ func extractImportValue(value *model.Value, fullPackageName string, aliases map[
switch {
case value.StructType != nil:
extractImport(value.StructType.Package, fullPackageName, aliases)

for _, arg := range value.StructType.TypeArgs {
extractImportValue(arg, fullPackageName, aliases)
}
case value.Array != nil:
extractImportValue(value.Array.Value, fullPackageName, aliases)
case value.Map != nil:
Expand Down
90 changes: 81 additions & 9 deletions internal/codegen/typescript.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,27 @@ var tsTypeAliases = map[string]string{
"time.Time": "number",
}

func renderTSTypeArgs(typeArgs []*model.Value, mappings config.TypeScriptMappings, scalars map[string]*model.Scalar, structs map[string]*model.Struct, ts *Code) {
if len(typeArgs) > 0 {
ts.App("<")

for i, arg := range typeArgs {
if i > 0 {
ts.App(",")
}

valueTSType(arg, mappings, scalars, structs, ts, nil)
}

ts.App(">")
}
}

func valueTSType(v *model.Value, mappings config.TypeScriptMappings, scalars map[string]*model.Scalar, structs map[string]*model.Struct, ts *Code, jsonInfo *model.JSONInfo) {
switch {
case v.TypeParam != "":
ts.App(v.TypeParam)
return
case jsonInfo != nil && len(jsonInfo.Type) > 0:
ts.App(jsonInfo.Type)
case v.Map != nil:
Expand Down Expand Up @@ -102,6 +121,7 @@ func valueTSType(v *model.Value, mappings config.TypeScriptMappings, scalars map
}

ts.App(tsModule + "." + v.StructType.Name)
renderTSTypeArgs(v.StructType.TypeArgs, mappings, scalars, structs, ts)

hiddenStruct, isHiddenStruct := structs[v.StructType.FullName()]
if isHiddenStruct && (hiddenStruct.Array != nil || hiddenStruct.Map != nil) && (jsonInfo == nil || !jsonInfo.OmitEmpty) {
Expand All @@ -114,6 +134,7 @@ func valueTSType(v *model.Value, mappings config.TypeScriptMappings, scalars map
}

ts.App(v.StructType.Name)
renderTSTypeArgs(v.StructType.TypeArgs, mappings, scalars, structs, ts)
case v.Struct != nil:
ts.L("{").Ind(1)
renderStructFields(v.Struct.Fields, mappings, scalars, structs, ts)
Expand Down Expand Up @@ -167,15 +188,66 @@ func renderStructFields(fields []*model.Field, mappings config.TypeScriptMapping
}
}

// mapKeyTypeParams returns the set of type parameter names used as map keys
// in the given struct's fields or as the struct-level map key.
func mapKeyTypeParams(str *model.Struct) map[string]bool {
result := make(map[string]bool)
if str.Map != nil && str.Map.Key != nil && str.Map.Key.TypeParam != "" {
result[str.Map.Key.TypeParam] = true
}

for _, f := range str.Fields {
collectMapKeyTypeParams(f.Value, result)
}

return result
}

func collectMapKeyTypeParams(v *model.Value, result map[string]bool) {
if v == nil {
return
}

if v.Map != nil {
if v.Map.Key != nil && v.Map.Key.TypeParam != "" {
result[v.Map.Key.TypeParam] = true
}

collectMapKeyTypeParams(v.Map.Value, result)
}

if v.Array != nil {
collectMapKeyTypeParams(v.Array.Value, result)
}
}

func tsTypeParamsSuffix(typeParams []string, mapKeys map[string]bool) string {
if len(typeParams) == 0 {
return ""
}

parts := make([]string, len(typeParams))
for i, tp := range typeParams {
if mapKeys[tp] {
parts[i] = tp + " extends string | number | symbol"
} else {
parts[i] = tp
}
}

return "<" + strings.Join(parts, ", ") + ">"
}

func renderTypescriptStruct(str *model.Struct, mappings config.TypeScriptMappings, scalars map[string]*model.Scalar, structs map[string]*model.Struct, ts *Code) error {
ts.L("// " + str.FullName())
typeParamsSuffix := tsTypeParamsSuffix(str.TypeParams, mapKeyTypeParams(str))

switch {
case str.Array != nil:
if str.Array.Len > 0 && str.Array.Value.ScalarType == model.ScalarTypeByte {
ts.App("export type " + str.Name + " = Uint8Array & { readonly length: " + fmt.Sprintf("%d", str.Array.Len) + " }")
ts.App("export type " + str.Name + typeParamsSuffix + " = Uint8Array & { readonly length: " + fmt.Sprintf("%d", str.Array.Len) + " }")
} else {
ts.App("export type " + str.Name + " = Array<")
ts.App("export type " + str.Name + typeParamsSuffix + " = Array<")
valueTSType(str.Array.Value, mappings, scalars, structs, ts, nil)
ts.App(">")
}
Expand All @@ -184,9 +256,9 @@ func renderTypescriptStruct(str *model.Struct, mappings config.TypeScriptMapping
case str.Map != nil:
enumKey := str.Map.Key != nil && str.Map.Key.Scalar != nil
if enumKey {
ts.App("export type " + str.Name + " = Partial<Record<")
ts.App("export type " + str.Name + typeParamsSuffix + " = Partial<Record<")
} else {
ts.App("export type " + str.Name + " = Record<")
ts.App("export type " + str.Name + typeParamsSuffix + " = Record<")
}

if str.Map.Key != nil {
Expand All @@ -213,7 +285,7 @@ func renderTypescriptStruct(str *model.Struct, mappings config.TypeScriptMapping

switch {
case str.UnionFields[0].Value.StructType != nil:
ts.App("export type " + str.Name + " = ")
ts.App("export type " + str.Name + typeParamsSuffix + " = ")

var isUndefined bool

Expand All @@ -235,7 +307,7 @@ func renderTypescriptStruct(str *model.Struct, mappings config.TypeScriptMapping

ts.NL()
case str.UnionFields[0].Value.Scalar != nil:
ts.App("export const " + str.Name + " = ")
ts.App("export const " + str.Name + typeParamsSuffix + " = ")
ts.App("{ ")

for i, field := range str.UnionFields {
Expand All @@ -249,7 +321,7 @@ func renderTypescriptStruct(str *model.Struct, mappings config.TypeScriptMapping

ts.App(" }")
ts.NL()
ts.App("export type " + str.Name + " = ")
ts.App("export type " + str.Name + typeParamsSuffix + " = ")

for i, field := range str.UnionFields {
if i > 0 {
Expand All @@ -266,7 +338,7 @@ func renderTypescriptStruct(str *model.Struct, mappings config.TypeScriptMapping
case len(str.InlineFields) > 0:
var extends bool

ts.App("export interface " + str.Name)
ts.App("export interface " + str.Name + typeParamsSuffix)

for i, inlineField := range str.InlineFields {
if inlineField.Value.Scalar != nil {
Expand Down Expand Up @@ -330,7 +402,7 @@ func renderTypescriptStruct(str *model.Struct, mappings config.TypeScriptMapping
renderStructFields(str.Fields, mappings, scalars, structs, ts)
ts.Ind(-1).L("}")
default:
ts.L("export interface " + str.Name + " {").Ind(1)
ts.L("export interface " + str.Name + typeParamsSuffix + " {").Ind(1)
renderStructFields(str.Fields, mappings, scalars, structs, ts)
ts.Ind(-1).L("}")
}
Expand Down
1 change: 1 addition & 0 deletions internal/model/struct.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ type Struct struct {
InlineFields []*Field
Map *Map
Array *Array
TypeParams []string
}

func (s *Struct) FullName() string {
Expand Down
5 changes: 3 additions & 2 deletions internal/model/structtype.go
Original file line number Diff line number Diff line change
@@ -1,8 +1,9 @@
package model

type StructType struct {
Name string
Package string
Name string
Package string
TypeArgs []*Value
}

func (st *StructType) FullName() string {
Expand Down
1 change: 1 addition & 0 deletions internal/model/value.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,5 @@ type Value struct {
Map *Map `json:",omitempty"`
Array *Array `json:",omitempty"`
IsPtr bool `json:",omitempty"`
TypeParam string `json:",omitempty"`
}
52 changes: 34 additions & 18 deletions internal/parser/servicereader.go
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ func readFields(fieldList *ast.FieldList, fileImports fileImportSpecMap) (fields
}

for _, param := range fieldList.List {
names, value, _ := readField(param, fileImports)
names, value, _ := readField(param, fileImports, nil)
for _, name := range names {
fields = append(fields, &model.Field{
Name: name,
Expand Down Expand Up @@ -400,6 +400,12 @@ func loadFlatStructsValue(s *model.Value, flatStructs map[string]bool) {
if s.Scalar != nil {
flatStructs[s.Scalar.FullName()] = true
}

if s.StructType != nil {
for _, arg := range s.StructType.TypeArgs {
loadFlatStructsValue(arg, flatStructs)
}
}
}

func fixFieldStructs(fields []*model.Field, structs map[string]*model.Struct, scalars map[string]*model.Scalar) {
Expand Down Expand Up @@ -567,6 +573,12 @@ func needsWorkValue(value *model.Value, needsWork func(fullName string) bool) bo
if needsWork(value.StructType.FullName()) {
return true
}

for _, arg := range value.StructType.TypeArgs {
if needsWorkValue(arg, needsWork) {
return true
}
}
case value.Array != nil:
if needsWorkValue(value.Array.Value, needsWork) {
return true
Expand Down Expand Up @@ -654,19 +666,22 @@ func getTypesInPackage(goPaths []string, gomod config.Namespace, packageName str
return structs, scalars, nil
}

func getStructTypeForField(value *model.Value) *model.StructType {
var strType *model.StructType
func getStructTypesForField(value *model.Value) []*model.StructType {
var types []*model.StructType

switch {
case value.StructType != nil:
strType = value.StructType
types = append(types, value.StructType)
for _, arg := range value.StructType.TypeArgs {
types = append(types, getStructTypesForField(arg)...)
}
case value.Map != nil:
strType = getStructTypeForField(value.Map.Value)
types = append(types, getStructTypesForField(value.Map.Value)...)
case value.Array != nil:
strType = getStructTypeForField(value.Array.Value)
types = append(types, getStructTypesForField(value.Array.Value)...)
}

return strType
return types
}

func getScalarForField(value *model.Value) []*model.Scalar {
Expand Down Expand Up @@ -712,18 +727,19 @@ func collectScalarTypes(fields []*model.Field, scalarTypes map[string]bool) {

func collectStructTypes(fields []*model.Field, structTypes map[string]bool) {
for _, field := range fields {
strType := getStructTypeForField(field.Value)
if strType != nil {
fullName := strType.Package + "." + strType.Name
if len(strType.Package) == 0 {
fullName = strType.Name
}
for _, strType := range getStructTypesForField(field.Value) {
if strType != nil {
fullName := strType.Package + "." + strType.Name
if len(strType.Package) == 0 {
fullName = strType.Name
}

switch fullName {
case "error", "net/http.Request", "net/http.ResponseWriter", "context.Context":
continue
default:
structTypes[fullName] = true
switch fullName {
case "error", "net/http.Request", "net/http.ResponseWriter", "context.Context":
continue
default:
structTypes[fullName] = true
}
}
}
}
Expand Down
Loading
Loading