Skip to content

Commit 1ac836e

Browse files
committed
Refactor server handlers and add JSON helper functions
1 parent 1288b83 commit 1ac836e

File tree

3 files changed

+105
-93
lines changed

3 files changed

+105
-93
lines changed

image.go

Lines changed: 12 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -99,11 +99,7 @@ func createLayerMetadata(layerDir, diffID string, index int, imageConfig *ImageC
9999
layerJSON["parent"] = prevDiffID
100100
}
101101

102-
layerJSONBytes, err := json.Marshal(layerJSON)
103-
if err != nil {
104-
return fmt.Errorf("failed to marshal layer JSON: %w", err)
105-
}
106-
return os.WriteFile(filepath.Join(layerDir, "json"), layerJSONBytes, 0644)
102+
return marshalJSONToFile(layerJSON, layerDir, "json")
107103
}
108104

109105
// downloadAllLayers downloads all layers and returns their diff IDs
@@ -141,11 +137,7 @@ func createDockerManifest(ref ImageReference, configDigest string, layerPaths []
141137
},
142138
}
143139

144-
manifestJSONBytes, err := json.Marshal(manifestJSON)
145-
if err != nil {
146-
return fmt.Errorf("failed to marshal manifest JSON: %w", err)
147-
}
148-
return os.WriteFile(filepath.Join(tempDir, "manifest.json"), manifestJSONBytes, 0644)
140+
return marshalJSONToFile(manifestJSON, tempDir, "manifest.json")
149141
}
150142

151143
// createRepositoriesFile creates the repositories file for docker load
@@ -157,11 +149,7 @@ func createRepositoriesFile(ref ImageReference, layerPaths []string, tempDir str
157149
imageName: {ref.Tag: topLayer},
158150
}
159151

160-
reposBytes, err := json.Marshal(repositories)
161-
if err != nil {
162-
return fmt.Errorf("failed to marshal repositories JSON: %w", err)
163-
}
164-
return os.WriteFile(filepath.Join(tempDir, "repositories"), reposBytes, 0644)
152+
return marshalJSONToFile(repositories, tempDir, "repositories")
165153
}
166154

167155
// createOutputTar creates the final tar archive
@@ -247,3 +235,12 @@ func GetImagePlatforms(imageRef string) ([]Platform, error) {
247235

248236
return client.GetPlatforms(ref)
249237
}
238+
239+
// marshalJSONToFile marshals v to JSON and writes it to dir/filename.
240+
func marshalJSONToFile(v interface{}, dir, filename string) error {
241+
data, err := json.Marshal(v)
242+
if err != nil {
243+
return fmt.Errorf("failed to marshal %s: %w", filename, err)
244+
}
245+
return os.WriteFile(filepath.Join(dir, filename), data, 0644)
246+
}

registry.go

Lines changed: 38 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ import (
1414

1515
const bearerPrefix = "Bearer "
1616
const responseBodyStr = "response body"
17+
const invalidImageReferenceFormat = "invalid image reference: %w"
1718

1819
// ImageReference represents a parsed Docker image reference
1920
type ImageReference struct {
@@ -133,14 +134,13 @@ func NewRegistryClient() *RegistryClient {
133134
// Authenticate obtains a token for the given image
134135
func (c *RegistryClient) Authenticate(ref ImageReference) error {
135136
if err := ValidateImageReference(ref); err != nil {
136-
return fmt.Errorf("invalid image reference: %w", err)
137+
return fmt.Errorf(invalidImageReferenceFormat, err)
137138
}
138139

139140
creds, hasCredentials := GetCredentials(ref.Registry)
141+
c.username = "anonymous"
140142
if hasCredentials {
141143
c.username = creds.Username
142-
} else {
143-
c.username = "anonymous"
144144
}
145145

146146
registryURL, err := buildRegistryURL(ref.Registry, "/v2/")
@@ -153,11 +153,12 @@ func (c *RegistryClient) Authenticate(ref ImageReference) error {
153153
}
154154
defer closeWithLog(resp.Body, responseBodyStr)
155155

156-
if resp.StatusCode == http.StatusOK {
156+
switch resp.StatusCode {
157+
case http.StatusOK:
157158
return nil // No auth required
158-
}
159-
160-
if resp.StatusCode != http.StatusUnauthorized {
159+
case http.StatusUnauthorized:
160+
// continue to token exchange below
161+
default:
161162
return fmt.Errorf("unexpected status: %d", resp.StatusCode)
162163
}
163164

@@ -168,56 +169,70 @@ func (c *RegistryClient) Authenticate(ref ImageReference) error {
168169

169170
realm, service, scope := parseAuthHeader(authHeader, ref.Repository)
170171

171-
// Validate the realm URL to prevent SSRF
172+
if err := validateAuthRealm(realm); err != nil {
173+
return err
174+
}
175+
176+
token, err := c.fetchToken(realm, service, scope, creds, hasCredentials)
177+
if err != nil {
178+
return err
179+
}
180+
c.token = token
181+
return nil
182+
}
183+
184+
// validateAuthRealm checks that the token realm URL is safe to contact (SSRF protection).
185+
func validateAuthRealm(realm string) error {
172186
parsedRealm, err := url.Parse(realm)
173187
if err != nil {
174188
return fmt.Errorf("invalid auth realm URL: %w", err)
175189
}
176190
if parsedRealm.Scheme != "https" && parsedRealm.Scheme != "http" {
177191
return fmt.Errorf("invalid auth realm scheme: %s", parsedRealm.Scheme)
178192
}
179-
// Validate the realm host to prevent SSRF to internal/private networks
180193
if err := validateRegistry(parsedRealm.Host); err != nil {
181194
return fmt.Errorf("invalid auth realm host: %w", err)
182195
}
196+
return nil
197+
}
183198

199+
// fetchToken requests a Bearer token from the auth realm and returns it.
200+
func (c *RegistryClient) fetchToken(realm, service, scope string, creds RegistryCredentials, hasCredentials bool) (string, error) {
184201
tokenURL := fmt.Sprintf("%s?service=%s&scope=%s", realm, url.QueryEscape(service), url.QueryEscape(scope))
185202

186203
req, err := http.NewRequest("GET", tokenURL, nil)
187204
if err != nil {
188-
return err
205+
return "", err
189206
}
190-
191207
if hasCredentials {
192208
auth := base64.StdEncoding.EncodeToString([]byte(creds.Username + ":" + creds.Password))
193209
req.Header.Set("Authorization", "Basic "+auth)
194210
}
195211

196-
resp, err = c.httpClient.Do(req)
212+
resp, err := c.httpClient.Do(req)
197213
if err != nil {
198-
return err
214+
return "", err
199215
}
200216
defer closeWithLog(resp.Body, responseBodyStr)
201217

202218
if resp.StatusCode != http.StatusOK {
203219
body, _ := io.ReadAll(io.LimitReader(resp.Body, 4096))
204-
return fmt.Errorf("authentication failed: %d - %s", resp.StatusCode, string(body))
220+
return "", fmt.Errorf("authentication failed: %d - %s", resp.StatusCode, string(body))
205221
}
206222

207223
var tokenResp struct {
208224
Token string `json:"token"`
209225
AccessToken string `json:"access_token"`
210226
}
211227
if err := json.NewDecoder(resp.Body).Decode(&tokenResp); err != nil {
212-
return err
228+
return "", err
213229
}
214230

215-
c.token = tokenResp.Token
216-
if c.token == "" {
217-
c.token = tokenResp.AccessToken
231+
token := tokenResp.Token
232+
if token == "" {
233+
token = tokenResp.AccessToken
218234
}
219-
220-
return nil
235+
return token, nil
221236
}
222237

223238
// GetAuthenticatedUser returns the username used for authentication
@@ -370,7 +385,7 @@ func (c *RegistryClient) doSafeRegistryRequest(registry, pathFormat string, head
370385
// fetchManifestResponse fetches the raw manifest response from the registry.
371386
func (c *RegistryClient) fetchManifestResponse(ref ImageReference, reference string) (*http.Response, error) {
372387
if err := ValidateImageReference(ref); err != nil {
373-
return nil, fmt.Errorf("invalid image reference: %w", err)
388+
return nil, fmt.Errorf(invalidImageReferenceFormat, err)
374389
}
375390

376391
headers := map[string]string{"Accept": manifestAcceptHeader}
@@ -379,10 +394,6 @@ func (c *RegistryClient) fetchManifestResponse(ref ImageReference, reference str
379394

380395
// getManifest retrieves the image manifest for the given platform
381396
func (c *RegistryClient) getManifest(ref ImageReference, platform Platform) (*ManifestV2, error) {
382-
if err := ValidateImageReference(ref); err != nil {
383-
return nil, fmt.Errorf("invalid image reference: %w", err)
384-
}
385-
386397
resp, err := c.fetchManifestResponse(ref, ref.Tag)
387398
if err != nil {
388399
return nil, err
@@ -405,7 +416,7 @@ func (c *RegistryClient) getManifest(ref ImageReference, platform Platform) (*Ma
405416

406417
func (c *RegistryClient) getManifestByDigest(ref ImageReference, digest string) (*ManifestV2, error) {
407418
if err := ValidateImageReference(ref); err != nil {
408-
return nil, fmt.Errorf("invalid image reference: %w", err)
419+
return nil, fmt.Errorf(invalidImageReferenceFormat, err)
409420
}
410421

411422
if err := validateDigest(digest); err != nil {
@@ -436,7 +447,7 @@ func (c *RegistryClient) getManifestByDigest(ref ImageReference, digest string)
436447
// DownloadBlob downloads a blob to a file
437448
func (c *RegistryClient) DownloadBlob(ref ImageReference, digest, destPath string) error {
438449
if err := ValidateImageReference(ref); err != nil {
439-
return fmt.Errorf("invalid image reference: %w", err)
450+
return fmt.Errorf(invalidImageReferenceFormat, err)
440451
}
441452

442453
if err := validateDigest(digest); err != nil {

server.go

Lines changed: 55 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -115,42 +115,16 @@ func (s *Server) logoHandler(w http.ResponseWriter, _ *http.Request) {
115115

116116
// imageHandler handles the /image endpoint
117117
func (s *Server) imageHandler(w http.ResponseWriter, r *http.Request) {
118-
119-
imageName := r.URL.Query().Get("name")
120-
if imageName == "" {
121-
writeJSONError(w, "missing required 'name' query parameter", http.StatusBadRequest)
118+
imageName, ok := extractImageName(w, r)
119+
if !ok {
122120
return
123121
}
124122

125-
imageName, err := sanitizeImageName(imageName)
126-
if err != nil {
127-
writeJSONError(w, fmt.Sprintf("invalid image name: %v", err), http.StatusBadRequest)
123+
platform, ok := platformFromRequest(w, r)
124+
if !ok {
128125
return
129126
}
130127

131-
platform := DefaultPlatform()
132-
if osParam := r.URL.Query().Get("os"); osParam != "" {
133-
if err := validatePlatformParam("os", osParam); err != nil {
134-
writeJSONError(w, err.Error(), http.StatusBadRequest)
135-
return
136-
}
137-
platform.OS = osParam
138-
}
139-
if archParam := r.URL.Query().Get("arch"); archParam != "" {
140-
if err := validatePlatformParam("arch", archParam); err != nil {
141-
writeJSONError(w, err.Error(), http.StatusBadRequest)
142-
return
143-
}
144-
platform.Architecture = archParam
145-
}
146-
if variantParam := r.URL.Query().Get("variant"); variantParam != "" {
147-
if err := validatePlatformParam("variant", variantParam); err != nil {
148-
writeJSONError(w, err.Error(), http.StatusBadRequest)
149-
return
150-
}
151-
platform.Variant = variantParam
152-
}
153-
154128
cachePath := s.cache.GetCachePath(imageName, platform)
155129

156130
if _, err := os.Stat(cachePath); err == nil {
@@ -160,7 +134,7 @@ func (s *Server) imageHandler(w http.ResponseWriter, r *http.Request) {
160134
}
161135

162136
log.Printf("Downloading image: %s (%s)\n", imageName, platform)
163-
sfKey := fmt.Sprintf("%s_%s_%s_%s", imageName, platform.OS, platform.Architecture, platform.Variant)
137+
sfKey := imageName + "_" + platform.String()
164138
result, err, _ := s.downloadGroup.Do(sfKey, func() (interface{}, error) {
165139
return DownloadImage(imageName, s.cache.Dir(), platform)
166140
})
@@ -177,15 +151,8 @@ func (s *Server) imageHandler(w http.ResponseWriter, r *http.Request) {
177151

178152
// platformsHandler handles the /platforms endpoint
179153
func (s *Server) platformsHandler(w http.ResponseWriter, r *http.Request) {
180-
imageName := r.URL.Query().Get("name")
181-
if imageName == "" {
182-
writeJSONError(w, "missing required 'name' query parameter", http.StatusBadRequest)
183-
return
184-
}
185-
186-
imageName, err := sanitizeImageName(imageName)
187-
if err != nil {
188-
writeJSONError(w, fmt.Sprintf("invalid image name: %v", err), http.StatusBadRequest)
154+
imageName, ok := extractImageName(w, r)
155+
if !ok {
189156
return
190157
}
191158

@@ -199,12 +166,46 @@ func (s *Server) platformsHandler(w http.ResponseWriter, r *http.Request) {
199166
platforms = []Platform{}
200167
}
201168

202-
w.Header().Set(contentTypeHeader, "application/json")
203-
if err := json.NewEncoder(w).Encode(map[string]interface{}{
204-
"platforms": platforms,
205-
}); err != nil {
206-
log.Printf("Failed to write platforms response: %v\n", err)
169+
writeJSON(w, http.StatusOK, map[string]interface{}{"platforms": platforms})
170+
}
171+
172+
// extractImageName reads and sanitizes the "name" query parameter, writing an
173+
// error response and returning false if it is missing or invalid.
174+
func extractImageName(w http.ResponseWriter, r *http.Request) (string, bool) {
175+
imageName := r.URL.Query().Get("name")
176+
if imageName == "" {
177+
writeJSONError(w, "missing required 'name' query parameter", http.StatusBadRequest)
178+
return "", false
179+
}
180+
imageName, err := sanitizeImageName(imageName)
181+
if err != nil {
182+
writeJSONError(w, fmt.Sprintf("invalid image name: %v", err), http.StatusBadRequest)
183+
return "", false
207184
}
185+
return imageName, true
186+
}
187+
188+
// platformFromRequest parses and validates the os/arch/variant query parameters,
189+
// writing an error response and returning false if any value is invalid.
190+
func platformFromRequest(w http.ResponseWriter, r *http.Request) (Platform, bool) {
191+
platform := DefaultPlatform()
192+
for _, field := range []struct {
193+
param string
194+
dest *string
195+
}{
196+
{"os", &platform.OS},
197+
{"arch", &platform.Architecture},
198+
{"variant", &platform.Variant},
199+
} {
200+
if val := r.URL.Query().Get(field.param); val != "" {
201+
if err := validatePlatformParam(field.param, val); err != nil {
202+
writeJSONError(w, err.Error(), http.StatusBadRequest)
203+
return Platform{}, false
204+
}
205+
*field.dest = val
206+
}
207+
}
208+
return platform, true
208209
}
209210

210211
// serveImageFile streams an image tar file to the response with Range request support
@@ -246,17 +247,20 @@ func (s *Server) serveImageFile(w http.ResponseWriter, r *http.Request, imagePat
246247
pullsCountMetric.Inc()
247248
}
248249

249-
// writeJSONError writes a JSON error response
250-
func writeJSONError(w http.ResponseWriter, message string, statusCode int) {
250+
// writeJSON writes a JSON response with the given status code.
251+
func writeJSON(w http.ResponseWriter, statusCode int, v interface{}) {
251252
w.Header().Set(contentTypeHeader, "application/json")
252253
w.WriteHeader(statusCode)
253-
err := json.NewEncoder(w).Encode(map[string]string{"error": message})
254-
if err != nil {
255-
errorsTotalMetric.Inc()
256-
log.Printf("Failed to write JSON error response: %v\n", err)
254+
if err := json.NewEncoder(w).Encode(v); err != nil {
255+
log.Printf("Failed to write JSON response: %v\n", err)
257256
}
258257
}
259258

259+
// writeJSONError writes a JSON error response.
260+
func writeJSONError(w http.ResponseWriter, message string, statusCode int) {
261+
writeJSON(w, statusCode, map[string]string{"error": message})
262+
}
263+
260264
// humanizeBytes converts bytes to a human-readable format
261265
func humanizeBytes(bytes int64) string {
262266
const unit = 1024

0 commit comments

Comments
 (0)