@@ -6,7 +6,10 @@ import (
66 "encoding/json"
77 "fmt"
88 "io"
9+ "mime"
910 "net/http"
11+
12+ "github.com/Rocket-Rescue-Node/guarded-beacon-proxy/ssz"
1013)
1114
1215// HTTPAuthenticator is a function type which can authenticate HTTP requests.
@@ -22,6 +25,9 @@ import (
2225// information.
2326type HTTPAuthenticator func (* http.Request ) (AuthenticationStatus , context.Context , error )
2427
28+ // If true is returned, the upstream will proxy the request.
29+ type httpGuard func (w http.ResponseWriter , r * http.Request ) bool
30+
2531func (gbp * GuardedBeaconProxy ) authenticationMiddleware (next http.Handler ) http.Handler {
2632 return http .HandlerFunc (func (w http.ResponseWriter , r * http.Request ) {
2733 status , context , err := gbp .HTTPAuthenticator (r )
@@ -40,16 +46,13 @@ func (gbp *GuardedBeaconProxy) authenticationMiddleware(next http.Handler) http.
4046}
4147
4248func cloneRequestBody (r * http.Request ) (io.ReadCloser , error ) {
43- // Read the body
44- buf , err := io . ReadAll ( r . Body )
45- if err != nil {
46- return nil , err
47- }
49+ // Use an io.TeeReader to return a reader that re-writes the body to the original request body.
50+ buf := bytes . NewBuffer ( nil )
51+ tee := io . TeeReader ( r . Body , buf )
52+ out := io . NopCloser ( tee )
53+ r . Body = io . NopCloser ( buf )
4854
49- original := io .NopCloser (bytes .NewBuffer (buf ))
50- clone := io .NopCloser (bytes .NewBuffer (buf ))
51- r .Body = original
52- return clone , nil
55+ return out , nil
5356}
5457
5558func (gbp * GuardedBeaconProxy ) httpError (w http.ResponseWriter , code int , err error ) {
@@ -60,46 +63,64 @@ func (gbp *GuardedBeaconProxy) httpError(w http.ResponseWriter, code int, err er
6063 }
6164}
6265
63- func (gbp * GuardedBeaconProxy ) prepareBeaconProposer (w http.ResponseWriter , r * http.Request ) {
64- buf , err := cloneRequestBody (r )
66+ func (gbp * GuardedBeaconProxy ) prepareBeaconProposer (w http.ResponseWriter , r * http.Request ) bool {
67+ reader , err := cloneRequestBody (r )
6568 if err != nil {
6669 gbp .httpError (w , http .StatusInternalServerError , nil )
67- return
70+ return false
6871 }
6972
7073 var proposers PrepareBeaconProposerRequest
71- if err := json .NewDecoder (buf ).Decode (& proposers ); err != nil {
74+ if err := json .NewDecoder (reader ).Decode (& proposers ); err != nil {
7275 gbp .httpError (w , http .StatusBadRequest , nil )
73- return
76+ return false
7477 }
7578
7679 status , err := gbp .PrepareBeaconProposerGuard (proposers , r .Context ())
7780 if status != Allowed {
7881 gbp .httpError (w , status .httpStatus (), err )
79- return
82+ return false
8083 }
8184
82- gbp . proxy . ServeHTTP ( w , r )
85+ return true
8386}
8487
85- func (gbp * GuardedBeaconProxy ) registerValidator (w http.ResponseWriter , r * http.Request ) {
86- buf , err := cloneRequestBody (r )
88+ func (gbp * GuardedBeaconProxy ) registerValidator (w http.ResponseWriter , r * http.Request ) bool {
89+ reader , err := cloneRequestBody (r )
8790 if err != nil {
8891 gbp .httpError (w , http .StatusInternalServerError , nil )
89- return
92+ return false
93+ }
94+
95+ // Check the content-type header
96+ contentType , _ , err := mime .ParseMediaType (r .Header .Get ("Content-Type" ))
97+ if err != nil {
98+ gbp .httpError (w , http .StatusUnsupportedMediaType , err )
99+ return false
90100 }
91101
92102 var validators RegisterValidatorRequest
93- if err := json .NewDecoder (buf ).Decode (& validators ); err != nil {
94- gbp .httpError (w , http .StatusBadRequest , nil )
95- return
103+ switch contentType {
104+ case "application/json" :
105+ if err := json .NewDecoder (reader ).Decode (& validators ); err != nil {
106+ gbp .httpError (w , http .StatusBadRequest , err )
107+ return false
108+ }
109+ case "application/octet-stream" :
110+ if err , status := ssz .ToRegisterValidatorRequest (& validators , reader , gbp .MaxRequestBodySize ); err != nil {
111+ gbp .httpError (w , status , err )
112+ return false
113+ }
114+ default :
115+ gbp .httpError (w , http .StatusUnsupportedMediaType , fmt .Errorf ("unsupported content type: %s" , contentType ))
116+ return false
96117 }
97118
98119 status , err := gbp .RegisterValidatorGuard (validators , r .Context ())
99120 if status != Allowed {
100121 gbp .httpError (w , status .httpStatus (), err )
101- return
122+ return false
102123 }
103124
104- gbp . proxy . ServeHTTP ( w , r )
125+ return true
105126}
0 commit comments