Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions docs/API-Docs.md
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ HTTP/JSON API exposed by `cmd/server`. All endpoints accept and return
- **Base URL:** `http://<host>:<port>` (default port `8080`, configurable via `HTTP_ADDR`)
- **Auth:** none currently enforced (deploy behind a trusted ingress / auth proxy)
- **Content-Type:** `application/json` is required on every request that has a body
- **Unknown fields:** request bodies with unknown JSON fields are rejected with `400`
- **Unknown fields / extra JSON values:** request bodies with unknown fields or trailing JSON data are rejected with `400`

---

Expand All @@ -31,7 +31,7 @@ Errors are returned with an appropriate HTTP status code and a JSON body:

| Status | Meaning | Triggered by |
|--------|---------|--------------|
| `400 Bad Request` | Malformed JSON, unknown field, missing required field, or unknown foreign-key reference | request body validation, `service.ErrInvalidInput` |
| `400 Bad Request` | Malformed JSON, missing `application/json` content type, unknown field, trailing JSON data, missing required field, or unknown foreign-key reference | request body validation, `service.ErrInvalidInput` |
| `404 Not Found` | Requested record does not exist | `service.ErrNotFound` |
| `409 Conflict` | Duplicate `email` or duplicate `originated_id` | `service.ErrAlreadyExists` |
| `500 Internal Server Error` | Unexpected server / database failure (driver message is logged, never returned) | any other error |
Expand Down
56 changes: 55 additions & 1 deletion internal/server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,16 @@ package server
import (
"encoding/json"
"errors"
"fmt"
"io"
"log/slog"
"mime"
"net/http"
"strings"
"time"

"github.com/go-sql-driver/mysql"

"github.com/apache/airavata-custos/pkg/models"
"github.com/apache/airavata-custos/pkg/service"
)
Expand All @@ -49,6 +54,7 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
}

func (s *Server) routes() {
s.mux.HandleFunc("GET /{$}", s.index)
s.mux.HandleFunc("GET /healthz", s.healthz)

s.mux.HandleFunc("POST /organizations", s.createOrganization)
Expand Down Expand Up @@ -145,6 +151,14 @@ func (s *Server) routes() {

}

func (s *Server) index(w http.ResponseWriter, _ *http.Request) {
writeJSON(w, http.StatusOK, map[string]any{
"name": "Apache Airavata Custos API",
"health": "/healthz",
"documentation": "docs/API-Docs.md",
})
}

func (s *Server) healthz(w http.ResponseWriter, _ *http.Request) {
writeJSON(w, http.StatusOK, map[string]string{"status": "ok"})
}
Expand Down Expand Up @@ -1023,9 +1037,27 @@ func (r *statusRecorder) WriteHeader(code int) {
}

func decodeJSON(r *http.Request, dst any) error {
contentType := r.Header.Get("Content-Type")
if contentType == "" {
return errors.New("Content-Type must be application/json")
}
mediaType, _, err := mime.ParseMediaType(contentType)
if err != nil || mediaType != "application/json" {
return errors.New("Content-Type must be application/json")
}

dec := json.NewDecoder(r.Body)
dec.DisallowUnknownFields()
return dec.Decode(dst)
if err := dec.Decode(dst); err != nil {
return err
}
if err := dec.Decode(&struct{}{}); !errors.Is(err, io.EOF) {
if err == nil {
return errors.New("request body must contain a single JSON value")
}
return fmt.Errorf("invalid trailing data after JSON value: %w", err)
}
return nil
}

func writeJSON(w http.ResponseWriter, status int, body any) {
Expand All @@ -1049,9 +1081,31 @@ func writeServiceError(w http.ResponseWriter, err error) {
writeError(w, http.StatusConflict, err)
case errors.Is(err, service.ErrInvalidInput):
writeError(w, http.StatusBadRequest, err)
case isDuplicateKeyError(err):
writeError(w, http.StatusConflict, service.ErrAlreadyExists)
case isInvalidConstraintError(err):
writeError(w, http.StatusBadRequest, service.ErrInvalidInput)
default:
// Avoid leaking driver messages to clients; log the full error.
slog.Error("internal server error", "error", err.Error())
writeError(w, http.StatusInternalServerError, errors.New(strings.TrimSpace("internal server error")))
}
}

func isDuplicateKeyError(err error) bool {
var mysqlErr *mysql.MySQLError
return errors.As(err, &mysqlErr) && mysqlErr.Number == 1062
}

func isInvalidConstraintError(err error) bool {
var mysqlErr *mysql.MySQLError
if !errors.As(err, &mysqlErr) {
return false
}
switch mysqlErr.Number {
case 1292, 1451, 1452:
return true
default:
return false
}
}
53 changes: 53 additions & 0 deletions internal/server/server_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

package server

import (
"net/http/httptest"
"strings"
"testing"
)

func TestDecodeJSONRequiresApplicationJSON(t *testing.T) {
req := httptest.NewRequest("POST", "/organizations", strings.NewReader(`{"name":"Example"}`))

var body map[string]string
if err := decodeJSON(req, &body); err == nil {
t.Fatal("expected missing Content-Type to be rejected")
}
}

func TestDecodeJSONRejectsTrailingData(t *testing.T) {
req := httptest.NewRequest("POST", "/organizations", strings.NewReader(`{"name":"Example"} {"x":1}`))
req.Header.Set("Content-Type", "application/json")

var body map[string]string
if err := decodeJSON(req, &body); err == nil {
t.Fatal("expected trailing JSON data to be rejected")
}
}

func TestDecodeJSONAcceptsApplicationJSONWithCharset(t *testing.T) {
req := httptest.NewRequest("POST", "/organizations", strings.NewReader(`{"name":"Example"}`))
req.Header.Set("Content-Type", "application/json; charset=utf-8")

var body map[string]string
if err := decodeJSON(req, &body); err != nil {
t.Fatalf("expected valid JSON body: %v", err)
}
}
27 changes: 27 additions & 0 deletions pkg/service/compute_allocation.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,12 +42,21 @@ func (s *Service) CreateComputeAllocation(ctx context.Context, alloc *models.Com
if alloc.ComputeClusterID == "" {
return nil, fmt.Errorf("%w: compute_cluster_id is required", ErrInvalidInput)
}
if alloc.InitialSUAmount < 0 {
return nil, fmt.Errorf("%w: initial_su_amount must be non-negative", ErrInvalidInput)
}
if err := validateRequiredTimeRange(alloc.StartTime, alloc.EndTime); err != nil {
return nil, err
}
if alloc.ID == "" {
alloc.ID = newID()
}
if alloc.Status == "" {
alloc.Status = models.ACTIVE
}
if err := validateAllocationStatus("status", alloc.Status); err != nil {
return nil, err
}

if proj, err := s.projs.FindByID(ctx, alloc.ProjectID); err != nil {
return nil, fmt.Errorf("lookup project: %w", err)
Expand Down Expand Up @@ -107,6 +116,24 @@ func (s *Service) UpdateComputeAllocation(ctx context.Context, alloc *models.Com
if alloc == nil || alloc.ID == "" {
return fmt.Errorf("%w: compute allocation id is required", ErrInvalidInput)
}
if alloc.Name == "" {
return fmt.Errorf("%w: compute allocation name is required", ErrInvalidInput)
}
if alloc.ProjectID == "" {
return fmt.Errorf("%w: project_id is required", ErrInvalidInput)
}
if alloc.ComputeClusterID == "" {
return fmt.Errorf("%w: compute_cluster_id is required", ErrInvalidInput)
}
if alloc.InitialSUAmount < 0 {
return fmt.Errorf("%w: initial_su_amount must be non-negative", ErrInvalidInput)
}
if err := validateRequiredTimeRange(alloc.StartTime, alloc.EndTime); err != nil {
return err
}
if err := validateAllocationStatus("status", alloc.Status); err != nil {
return err
}
if err := s.inTx(ctx, func(tx *sql.Tx) error {
return s.allocs.Update(ctx, tx, alloc)
}); err != nil {
Expand Down
16 changes: 16 additions & 0 deletions pkg/service/compute_allocation_change_request.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,14 @@ func (s *Service) CreateComputeAllocationChangeRequest(ctx context.Context, req
if req.RequesterID == "" {
return nil, fmt.Errorf("%w: requester_id is required", ErrInvalidInput)
}
if req.RequestedSUAmount < 0 {
return nil, fmt.Errorf("%w: requested_su_amount must be non-negative", ErrInvalidInput)
}
if req.RequestedStatus != "" {
if err := validateAllocationStatus("requested_status", req.RequestedStatus); err != nil {
return nil, err
}
}

if alloc, err := s.allocs.FindByID(ctx, req.ComputeAllocationID); err != nil {
return nil, fmt.Errorf("lookup compute allocation: %w", err)
Expand Down Expand Up @@ -138,6 +146,14 @@ func (s *Service) UpdateComputeAllocationChangeRequest(ctx context.Context, req
if req.ChangeStatus == "" {
req.ChangeStatus = existing.ChangeStatus
}
if req.RequestedSUAmount < 0 {
return nil, fmt.Errorf("%w: requested_su_amount must be non-negative", ErrInvalidInput)
}
if req.RequestedStatus != "" {
if err := validateAllocationStatus("requested_status", req.RequestedStatus); err != nil {
return nil, err
}
}
if req.Timestamp.IsZero() {
req.Timestamp = existing.Timestamp
}
Expand Down
6 changes: 6 additions & 0 deletions pkg/service/compute_allocation_diff.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,12 @@ func (s *Service) CreateComputeAllocationDiff(ctx context.Context, diff *models.
if diff.Status == "" {
return nil, fmt.Errorf("%w: status is required", ErrInvalidInput)
}
if diff.NewSUAmount < 0 {
return nil, fmt.Errorf("%w: new_su_amount must be non-negative", ErrInvalidInput)
}
if err := validateAllocationStatus("status", diff.Status); err != nil {
return nil, err
}

if alloc, err := s.allocs.FindByID(ctx, diff.ComputeAllocationID); err != nil {
return nil, fmt.Errorf("lookup compute allocation: %w", err)
Expand Down
15 changes: 15 additions & 0 deletions pkg/service/compute_allocation_membership.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ func (s *Service) CreateComputeAllocationMembership(ctx context.Context, m *mode
if m.UserID == "" {
return nil, fmt.Errorf("%w: user_id is required", ErrInvalidInput)
}
if err := validateRequiredTimeRange(m.StartTime, m.EndTime); err != nil {
return nil, err
}

if alloc, err := s.allocs.FindByID(ctx, m.ComputeAllocationID); err != nil {
return nil, fmt.Errorf("lookup compute allocation: %w", err)
Expand All @@ -64,6 +67,9 @@ func (s *Service) CreateComputeAllocationMembership(ctx context.Context, m *mode
if m.MembershipStatus == "" {
m.MembershipStatus = models.ACTIVE
}
if err := validateAllocationStatus("membership_status", m.MembershipStatus); err != nil {
return nil, err
}

if err := s.inTx(ctx, func(tx *sql.Tx) error {
return s.memberships.Create(ctx, tx, m)
Expand Down Expand Up @@ -143,6 +149,12 @@ func (s *Service) UpdateComputeAllocationMembership(ctx context.Context, m *mode
if m.MembershipStatus == "" {
m.MembershipStatus = existing.MembershipStatus
}
if err := validateRequiredTimeRange(m.StartTime, m.EndTime); err != nil {
return nil, err
}
if err := validateAllocationStatus("membership_status", m.MembershipStatus); err != nil {
return nil, err
}

if err := s.inTx(ctx, func(tx *sql.Tx) error {
return s.memberships.Update(ctx, tx, m)
Expand All @@ -163,6 +175,9 @@ func (s *Service) UpdateMembershipStatus(ctx context.Context, id string, status
if status == "" {
return nil, fmt.Errorf("%w: membership_status is required", ErrInvalidInput)
}
if err := validateAllocationStatus("membership_status", status); err != nil {
return nil, err
}
existing, err := s.memberships.FindByID(ctx, id)
if err != nil {
return nil, fmt.Errorf("lookup compute allocation membership: %w", err)
Expand Down
12 changes: 12 additions & 0 deletions pkg/service/compute_allocation_resource.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ func (s *Service) CreateComputeAllocationResource(ctx context.Context, resource
if resource.ResourceType == "" {
return nil, fmt.Errorf("%w: resource_type is required", ErrInvalidInput)
}
if resource.ResourceAmount < 0 {
return nil, fmt.Errorf("%w: resource_amount must be non-negative", ErrInvalidInput)
}
if resource.ID == "" {
resource.ID = newID()
}
Expand Down Expand Up @@ -79,6 +82,15 @@ func (s *Service) UpdateComputeAllocationResource(ctx context.Context, resource
if resource == nil || resource.ID == "" {
return fmt.Errorf("%w: compute allocation resource id is required", ErrInvalidInput)
}
if resource.Name == "" {
return fmt.Errorf("%w: resource name is required", ErrInvalidInput)
}
if resource.ResourceType == "" {
return fmt.Errorf("%w: resource_type is required", ErrInvalidInput)
}
if resource.ResourceAmount < 0 {
return fmt.Errorf("%w: resource_amount must be non-negative", ErrInvalidInput)
}
if err := s.inTx(ctx, func(tx *sql.Tx) error {
return s.resources.Update(ctx, tx, resource)
}); err != nil {
Expand Down
9 changes: 9 additions & 0 deletions pkg/service/project.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ func (s *Service) CreateProject(ctx context.Context, project *models.Project) (*
if project.Status == "" {
project.Status = models.ProjectActive
}
if err := validateProjectStatus(project.Status); err != nil {
return nil, err
}
if project.CreatedTime.IsZero() {
project.CreatedTime = nowUTC()
}
Expand Down Expand Up @@ -134,6 +137,9 @@ func (s *Service) UpdateProject(ctx context.Context, project *models.Project) er
if project.Status == "" {
project.Status = existing.Status
}
if err := validateProjectStatus(project.Status); err != nil {
return err
}
if project.CreatedTime.IsZero() {
project.CreatedTime = existing.CreatedTime
}
Expand All @@ -156,6 +162,9 @@ func (s *Service) UpdateProjectStatus(ctx context.Context, id string, status mod
if status == "" {
return nil, fmt.Errorf("%w: status is required", ErrInvalidInput)
}
if err := validateProjectStatus(status); err != nil {
return nil, err
}
existing, err := s.projs.FindByID(ctx, id)
if err != nil {
return nil, fmt.Errorf("lookup project: %w", err)
Expand Down
9 changes: 9 additions & 0 deletions pkg/service/user.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@ func (s *Service) CreateUser(ctx context.Context, user *models.User) (*models.Us
if user.Status == "" {
user.Status = models.UserActive
}
if err := validateUserStatus(user.Status); err != nil {
return nil, err
}

if err := s.inTx(ctx, func(tx *sql.Tx) error {
return s.users.Create(ctx, tx, user)
Expand Down Expand Up @@ -160,6 +163,9 @@ func (s *Service) UpdateUser(ctx context.Context, user *models.User) error {
if user.Status == "" {
user.Status = existing.Status
}
if err := validateUserStatus(user.Status); err != nil {
return err
}

if err := s.inTx(ctx, func(tx *sql.Tx) error {
return s.users.Update(ctx, tx, user)
Expand All @@ -180,6 +186,9 @@ func (s *Service) UpdateUserStatus(ctx context.Context, id string, status models
if status == "" {
return nil, fmt.Errorf("%w: status is required", ErrInvalidInput)
}
if err := validateUserStatus(status); err != nil {
return nil, err
}
existing, err := s.users.FindByID(ctx, id)
if err != nil {
return nil, fmt.Errorf("lookup user: %w", err)
Expand Down
Loading