Skip to content

Commit 73fc47f

Browse files
lordspinachInsei
authored andcommitted
add auto mapping with manual routing ability and auto mapping nested structs
Signed-off-by: Konstantin Gamayunov <ksgamayunov@gmail.com>
1 parent 4f4816e commit 73fc47f

File tree

5 files changed

+575
-0
lines changed

5 files changed

+575
-0
lines changed

auto.go

Lines changed: 92 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
package gomapper
2+
3+
import (
4+
"github.com/insei/gomapper/fields"
5+
"reflect"
6+
"strings"
7+
)
8+
9+
var manualFieldRoutes = map[reflect.Type]map[string]string{}
10+
11+
func AutoRoute[TSource, TDest any | []any](options ...AutoMapperOption) error {
12+
s := new(TSource)
13+
d := new(TDest)
14+
sourceFields := fields.GetFrom(s)
15+
destFields := fields.GetFrom(d)
16+
sourceType := reflect.TypeOf(s)
17+
18+
parseOptions(options, sourceType)
19+
20+
mapFunc := func(source TSource, dest *TDest) error {
21+
for key, sourceFld := range sourceFields {
22+
destFld, ok := destFields[getDestFieldName(sourceType, key)]
23+
if !ok || strings.Contains(key, ".") {
24+
continue
25+
}
26+
if err := setFieldRecursive(sourceFld, destFld, source, dest); err != nil {
27+
return err
28+
}
29+
}
30+
return nil
31+
}
32+
return AddRoute[TSource, TDest](mapFunc)
33+
}
34+
35+
func parseOptions(options []AutoMapperOption, sourceType reflect.Type) {
36+
for _, option := range options {
37+
switch autoMapperOption := option.(type) {
38+
case fieldPathOption:
39+
if manualFieldRoutes[sourceType] == nil {
40+
manualFieldRoutes[sourceType] = map[string]string{}
41+
}
42+
manualFieldRoutes[sourceType][autoMapperOption.source] = autoMapperOption.dest
43+
}
44+
}
45+
}
46+
47+
func setFieldRecursive(sourceFld, destFld fields.Field, source, dest any) error {
48+
if sourceFld.Type.Kind() != reflect.Struct {
49+
sourceVal := sourceFld.Get(source)
50+
if sourceVal != nil {
51+
destFld.Set(dest, sourceVal)
52+
}
53+
return nil
54+
}
55+
56+
if r, ok := getRouteIfExists(sourceFld, destFld); ok {
57+
return r(sourceFld.Get(source), destFld.Get(dest))
58+
}
59+
60+
sourceStructField := sourceFld.Get(source)
61+
sourceFields := fields.GetFrom(sourceStructField)
62+
destStructField := destFld.Get(dest)
63+
destFields := fields.GetFrom(destStructField)
64+
65+
for fieldName, sField := range sourceFields {
66+
dField, ok := destFields[getDestFieldName(sField.Type, fieldName)]
67+
if !ok || strings.Contains(fieldName, ".") {
68+
continue
69+
}
70+
err := setFieldRecursive(sField, dField, sourceStructField, destStructField)
71+
if err != nil {
72+
return err
73+
}
74+
}
75+
return nil
76+
}
77+
78+
func getRouteIfExists(sourceFld, destFld fields.Field) (func(source interface{}, dest interface{}) error, bool) {
79+
destType := destFld.Type
80+
if destType.Kind() != reflect.Ptr {
81+
destType = reflect.PointerTo(destType)
82+
}
83+
r, ok := routes[sourceFld.Type][destType]
84+
return r, ok
85+
}
86+
87+
func getDestFieldName(sourceFieldType reflect.Type, sourceFieldName string) string {
88+
if destFieldName, ok := manualFieldRoutes[sourceFieldType][sourceFieldName]; ok {
89+
return destFieldName
90+
}
91+
return sourceFieldName
92+
}

auto_test.go

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
package gomapper
2+
3+
import (
4+
"github.com/stretchr/testify/assert"
5+
"testing"
6+
)
7+
8+
type AutoMappingStructSource struct {
9+
Name string
10+
NestedStruct NestedStructSource
11+
}
12+
13+
type AutoMappingStructDest struct {
14+
Name string
15+
SecondName string
16+
NestedStruct NestedStructDest
17+
}
18+
19+
type NestedStructSource struct {
20+
FirstNestedName string
21+
DeepNestedStruct DeepNestedStructSource
22+
}
23+
24+
type NestedStructDest struct {
25+
FirstNestedName string
26+
FirstNestedSecondName string
27+
DeepNestedStruct DeepNestedStructDest
28+
}
29+
30+
type DeepNestedStructSource struct {
31+
SecondNestedName string
32+
}
33+
34+
type DeepNestedStructDest struct {
35+
SecondNestedName string
36+
}
37+
38+
func TestAutoRoute(t *testing.T) {
39+
_ = AutoRoute[AutoMappingStructSource, AutoMappingStructDest]()
40+
t.Run("Auto route without options", func(t *testing.T) {
41+
source := &AutoMappingStructSource{Name: "Test1"}
42+
dest, err := MapTo[AutoMappingStructDest](source)
43+
assert.NoError(t, err)
44+
assert.Equal(t, source.Name, dest.Name)
45+
})
46+
_ = AutoRoute[AutoMappingStructSource, AutoMappingStructDest](WithFieldRoute("Name", "SecondName"))
47+
t.Run("Auto route with options", func(t *testing.T) {
48+
source := &AutoMappingStructSource{Name: "Test1"}
49+
dest, err := MapTo[AutoMappingStructDest](source)
50+
assert.NoError(t, err)
51+
assert.Equal(t, "", dest.Name)
52+
assert.Equal(t, source.Name, dest.SecondName)
53+
})
54+
t.Run("Auto mapping struct fields", func(t *testing.T) {
55+
source := &AutoMappingStructSource{
56+
NestedStruct: NestedStructSource{
57+
FirstNestedName: "Test1",
58+
DeepNestedStruct: DeepNestedStructSource{
59+
SecondNestedName: "Test2",
60+
},
61+
},
62+
}
63+
dest, err := MapTo[AutoMappingStructDest](source)
64+
assert.NoError(t, err)
65+
assert.Equal(t, source.NestedStruct.FirstNestedName, dest.NestedStruct.FirstNestedName)
66+
assert.Equal(t, source.NestedStruct.DeepNestedStruct.SecondNestedName, dest.NestedStruct.DeepNestedStruct.SecondNestedName)
67+
})
68+
69+
_ = AddRoute[NestedStructSource, NestedStructDest](func(source NestedStructSource, dest *NestedStructDest) error {
70+
dest.FirstNestedSecondName = source.FirstNestedName
71+
return nil
72+
})
73+
t.Run("Auto mapping using existing route on nested struct", func(t *testing.T) {
74+
source := &AutoMappingStructSource{
75+
NestedStruct: NestedStructSource{
76+
FirstNestedName: "Test1",
77+
},
78+
}
79+
dest, err := MapTo[AutoMappingStructDest](source)
80+
assert.NoError(t, err)
81+
assert.Equal(t, "", dest.NestedStruct.FirstNestedName)
82+
assert.Equal(t, source.NestedStruct.FirstNestedName, dest.NestedStruct.FirstNestedSecondName)
83+
})
84+
}

fields/fields.go

Lines changed: 207 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,207 @@
1+
package fields
2+
3+
import (
4+
"reflect"
5+
"unsafe"
6+
)
7+
8+
type Field reflect.StructField
9+
10+
var fields = map[reflect.Type]map[string]Field{}
11+
12+
// Get returns the value of the fields in the provided object.
13+
// It takes a parameter `obj` of type `interface{}`, representing the object.
14+
// It returns the value of the fields as an `interface{}`.
15+
// If the fields type is `string`, it returns the value as a `string`.
16+
// If the fields type is `int`, it returns the value as an `int`.
17+
// If the fields type is `bool`, it returns the value as a `bool`.
18+
// If the fields type is not handled, it panics with an error message.
19+
func (f Field) Get(obj interface{}) interface{} {
20+
ptrToField := f.getPtr(obj)
21+
kind := f.Type.Kind()
22+
isPtr := false
23+
if kind == reflect.Ptr {
24+
isPtr = true
25+
kind = f.Type.Elem().Kind()
26+
}
27+
if isPtr {
28+
switch kind {
29+
case reflect.String:
30+
return getPtrValue[*string](ptrToField)
31+
case reflect.Int:
32+
return getPtrValue[*int](ptrToField)
33+
case reflect.Int8:
34+
return getPtrValue[*int8](ptrToField)
35+
case reflect.Int16:
36+
return getPtrValue[*int16](ptrToField)
37+
case reflect.Int32:
38+
return getPtrValue[*int32](ptrToField)
39+
case reflect.Int64:
40+
return getPtrValue[*int64](ptrToField)
41+
case reflect.Float32:
42+
return getPtrValue[*float32](ptrToField)
43+
case reflect.Float64:
44+
return getPtrValue[*float64](ptrToField)
45+
case reflect.Bool:
46+
return getPtrValue[*bool](ptrToField)
47+
case reflect.Struct:
48+
return reflect.NewAt(f.Type, ptrToField).Interface()
49+
default:
50+
panic("unhandled default case")
51+
}
52+
} else {
53+
switch kind {
54+
case reflect.String:
55+
return getPtrValue[string](ptrToField)
56+
case reflect.Int:
57+
return getPtrValue[int](ptrToField)
58+
case reflect.Int8:
59+
return getPtrValue[int8](ptrToField)
60+
case reflect.Int16:
61+
return getPtrValue[int16](ptrToField)
62+
case reflect.Int32:
63+
return getPtrValue[int32](ptrToField)
64+
case reflect.Int64:
65+
return getPtrValue[int64](ptrToField)
66+
case reflect.Float32:
67+
return getPtrValue[float32](ptrToField)
68+
case reflect.Float64:
69+
return getPtrValue[float64](ptrToField)
70+
case reflect.Bool:
71+
return getPtrValue[bool](ptrToField)
72+
case reflect.Struct:
73+
return reflect.NewAt(f.Type, ptrToField).Interface()
74+
default:
75+
panic("unhandled default case")
76+
}
77+
}
78+
}
79+
80+
// getPtr returns a pointer to the field's value in the provided configuration object.
81+
// It takes a parameter `conf` of type `any`, representing the configuration object.
82+
// It returns an `unsafe.Pointer` to the `field's` value in the configuration object.
83+
func (f Field) getPtr(conf interface{}) unsafe.Pointer {
84+
confPointer := ((*[2]unsafe.Pointer)(unsafe.Pointer(&conf)))[1]
85+
ptToField := unsafe.Add(confPointer, f.Offset)
86+
return ptToField
87+
}
88+
89+
func setPtrValue[T any](ptr unsafe.Pointer, val any) {
90+
valSet := (*T)(ptr)
91+
*valSet = val.(T)
92+
}
93+
94+
func getPtrValue[T any](ptr unsafe.Pointer) T {
95+
return *(*T)(ptr)
96+
}
97+
98+
// Set updates the value of the fields in the provided object with the provided value.
99+
// It takes two parameters:
100+
// - obj: interface{}, representing the object containing the fields.
101+
// - val: interface{}, representing the new value for the fields.
102+
//
103+
// The Set method uses the getPtr method to get a pointer to the fields in the object.
104+
// It then performs a type switch on the kind of the fields to determine its type, and sets the value accordingly.
105+
// The supported fields types are string, int, and bool.
106+
// If the fields type is not one of the supported types, it panics with the message "unhandled default case".
107+
func (f Field) Set(obj interface{}, val interface{}) {
108+
ptrToField := f.getPtr(obj)
109+
kind := f.Type.Kind()
110+
isPtr := false
111+
if kind == reflect.Ptr {
112+
isPtr = true
113+
kind = f.Type.Elem().Kind()
114+
}
115+
if isPtr {
116+
switch kind {
117+
case reflect.String:
118+
setPtrValue[*string](ptrToField, val)
119+
case reflect.Int:
120+
setPtrValue[*int](ptrToField, val)
121+
case reflect.Int8:
122+
setPtrValue[*int8](ptrToField, val)
123+
case reflect.Int16:
124+
setPtrValue[*int16](ptrToField, val)
125+
case reflect.Int32:
126+
setPtrValue[*int32](ptrToField, val)
127+
case reflect.Int64:
128+
setPtrValue[*int64](ptrToField, val)
129+
case reflect.Float32:
130+
setPtrValue[*float32](ptrToField, val)
131+
case reflect.Float64:
132+
setPtrValue[*float64](ptrToField, val)
133+
case reflect.Bool:
134+
setPtrValue[*bool](ptrToField, val)
135+
default:
136+
panic("unhandled default case")
137+
}
138+
} else {
139+
switch kind {
140+
case reflect.String:
141+
setPtrValue[string](ptrToField, val)
142+
case reflect.Int:
143+
setPtrValue[int](ptrToField, val)
144+
case reflect.Int8:
145+
setPtrValue[int8](ptrToField, val)
146+
case reflect.Int16:
147+
setPtrValue[int16](ptrToField, val)
148+
case reflect.Int32:
149+
setPtrValue[int32](ptrToField, val)
150+
case reflect.Int64:
151+
setPtrValue[int64](ptrToField, val)
152+
case reflect.Float32:
153+
setPtrValue[float32](ptrToField, val)
154+
case reflect.Float64:
155+
setPtrValue[float64](ptrToField, val)
156+
case reflect.Bool:
157+
setPtrValue[bool](ptrToField, val)
158+
default:
159+
panic("unhandled default case")
160+
}
161+
}
162+
}
163+
164+
// Get returns a map of FieldGetSet objects representing the fields of the provided object.
165+
// It takes a parameter `conf` of type `any`, representing the object.
166+
// It returns a map with string keys (fields path) and FieldGetSet values.
167+
// FieldGetSet is an interface that defines Get and Set methods for fields.
168+
func Get[T interface{}]() map[string]Field {
169+
obj := new(T)
170+
return GetFrom(obj)
171+
}
172+
173+
func GetFrom(obj interface{}) map[string]Field {
174+
typeOf := reflect.TypeOf(obj)
175+
if tFields, ok := fields[typeOf]; ok {
176+
return tFields
177+
}
178+
tFields := map[string]Field{}
179+
getFieldsMapRecursive(obj, "", &tFields)
180+
fields[typeOf] = tFields
181+
return tFields
182+
}
183+
184+
func getFieldsMapRecursive(conf any, path string, f *map[string]Field) {
185+
typeOf := reflect.TypeOf(conf)
186+
valueOf := reflect.ValueOf(conf)
187+
if reflect.ValueOf(conf).Kind() == reflect.Ptr {
188+
typeOf = typeOf.Elem()
189+
valueOf = valueOf.Elem()
190+
}
191+
if path != "" {
192+
path += "."
193+
}
194+
for i := 0; i < typeOf.NumField(); i++ {
195+
fieldTypeOf := typeOf.Field(i)
196+
fieldValueOf := valueOf.Field(i)
197+
switch fieldTypeOf.Type.Kind() {
198+
case reflect.Slice:
199+
break
200+
case reflect.Struct:
201+
(*f)[path+fieldTypeOf.Name] = Field(fieldTypeOf)
202+
getFieldsMapRecursive(fieldValueOf.Addr().Interface(), path+fieldTypeOf.Name, f)
203+
default:
204+
(*f)[path+fieldTypeOf.Name] = Field(fieldTypeOf)
205+
}
206+
}
207+
}

0 commit comments

Comments
 (0)