Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions internal/core/services/instance.go
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,10 @@ func (s *InstanceService) LaunchInstance(ctx context.Context, params ports.Launc
// Rollback quota reservation on enqueue failure
_ = s.tenantSvc.DecrementUsage(ctx, tenantID, "vcpus", it.VCPUs)
_ = s.tenantSvc.DecrementUsage(ctx, tenantID, "memory", it.MemoryMB/1024)
// Rollback database record creation
if delErr := s.repo.Delete(ctx, inst.ID); delErr != nil {
s.logger.Error("failed to delete instance record after enqueue failure", "instance_id", inst.ID, "error", delErr)
}
return nil, errors.Wrap(errors.Internal, "failed to enqueue provisioning task", err)
}

Expand Down Expand Up @@ -313,6 +317,10 @@ func (s *InstanceService) LaunchInstanceWithOptions(ctx context.Context, opts po

if err := s.taskQueue.Enqueue(ctx, "provision_queue", job); err != nil {
s.logger.Error("failed to enqueue provision job", "instance_id", inst.ID, "error", err)
// Rollback database record creation
if delErr := s.repo.Delete(ctx, inst.ID); delErr != nil {
s.logger.Error("failed to delete instance record after enqueue failure", "instance_id", inst.ID, "error", delErr)
}
return nil, errors.Wrap(errors.Internal, "failed to enqueue provisioning task", err)
}

Expand Down
109 changes: 105 additions & 4 deletions internal/core/services/instance_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,17 @@ func (r *FaultyInstanceRepository) Create(ctx context.Context, instance *domain.
}

type InMemoryTaskQueue struct {
jobs []string
mu sync.Mutex
jobs []string
mu sync.Mutex
ShouldFail bool
}

func (q *InMemoryTaskQueue) Enqueue(ctx context.Context, queueName string, payload interface{}) error {
q.mu.Lock()
defer q.mu.Unlock()
if q.ShouldFail {
return fmt.Errorf("simulated enqueue failure")
}
q.jobs = append(q.jobs, fmt.Sprintf("%v", payload))
return nil
}
Expand Down Expand Up @@ -334,7 +338,7 @@ func TestInstanceServiceLaunchDBFailure(t *testing.T) {
assert.Contains(t, err.Error(), "simulated database failure")

// Verify no junk in DB (using real repo to check)
list, err := realRepo.List(ctx)
list, err := realRepo.List(ctx, nil)
require.NoError(t, err)
assert.Empty(t, list)
}
Expand Down Expand Up @@ -413,7 +417,7 @@ func TestInstanceServiceLaunchConcurrency(t *testing.T) {
}

// Verify all created
list, err := repo.List(ctx)
list, err := repo.List(ctx, nil)
require.NoError(t, err)
assert.Len(t, list, concurrency)

Expand Down Expand Up @@ -709,3 +713,100 @@ func TestLaunchInstanceWithOptions(t *testing.T) {
_ = compute.DeleteInstance(ctx, inst.ContainerID)
}
}

func TestInstanceServiceLaunchEnqueueFailure(t *testing.T) {
db := setupDB(t)
ctx := setupTestUser(t, db)

repo := postgres.NewInstanceRepository(db)
vpcRepo := postgres.NewVpcRepository(db)
subnetRepo := postgres.NewSubnetRepository(db)
volumeRepo := postgres.NewVolumeRepository(db)
itRepo := postgres.NewInstanceTypeRepository(db)

compute := noop.NewNoopComputeBackend()

defaultType := &domain.InstanceType{ID: testInstanceType, Name: "Basic 2", VCPUs: 1, MemoryMB: 128, DiskGB: 1}
_, err := itRepo.Create(ctx, defaultType)
require.NoError(t, err)

rbacSvc := new(MockRBACService)
rbacSvc.On("Authorize", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil)

eventSvc := services.NewEventService(services.EventServiceParams{
Repo: postgres.NewEventRepository(db),
RBACSvc: rbacSvc,
Publisher: nil,
Logger: slog.Default(),
})
auditSvc := services.NewAuditService(services.AuditServiceParams{
Repo: postgres.NewAuditRepository(db),
RBACSvc: rbacSvc,
})

// Create a task queue that fails
taskQueue := &InMemoryTaskQueue{ShouldFail: true}

sshKeySvc, err := services.NewSSHKeyService(services.SSHKeyServiceParams{
Repo: postgres.NewSSHKeyRepo(db),
RBACSvc: rbacSvc,
})
require.NoError(t, err)

tenantSvc := services.NewTenantService(services.TenantServiceParams{
Repo: postgres.NewTenantRepo(db),
UserRepo: postgres.NewUserRepo(db),
RBACSvc: rbacSvc,
Logger: slog.Default(),
})

svc := services.NewInstanceService(services.InstanceServiceParams{
Repo: repo,
VpcRepo: vpcRepo,
SubnetRepo: subnetRepo,
VolumeRepo: volumeRepo,
InstanceTypeRepo: itRepo,
RBAC: rbacSvc,
Compute: compute,
EventSvc: eventSvc,
AuditSvc: auditSvc,
TaskQueue: taskQueue,
Logger: slog.Default(),
TenantSvc: tenantSvc,
SSHKeySvc: sshKeySvc,
})

// Attempt LaunchInstance
name := "enqueue-fail-integration"
_, err = svc.LaunchInstance(ctx, coreports.LaunchParams{
Name: name,
Image: testImage,
InstanceType: testInstanceType,
})

// Verify Failure
require.Error(t, err)
assert.Contains(t, err.Error(), "failed to enqueue provisioning task")

// Verify instance not in DB
inst, err := repo.GetByName(ctx, name)
require.Error(t, err)
assert.Nil(t, inst)

// Attempt LaunchInstanceWithOptions
optsName := "enqueue-fail-opts-integration"
opts := coreports.CreateInstanceOptions{
Name: optsName,
ImageName: testImage,
}
_, err = svc.LaunchInstanceWithOptions(ctx, opts)

// Verify Failure
require.Error(t, err)
assert.Contains(t, err.Error(), "failed to enqueue provisioning task")

// Verify instance not in DB
instOpts, err := repo.GetByName(ctx, optsName)
require.Error(t, err)
assert.Nil(t, instOpts)
}
154 changes: 154 additions & 0 deletions internal/core/services/instance_unit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,160 @@ func testInstanceServiceLaunchInstanceUnit(t *testing.T) {
require.Error(t, err)
assert.Contains(t, err.Error(), "quota exceeded")
})

t.Run("EnqueueFailure", func(t *testing.T) {
params := ports.LaunchParams{
Name: "enqueue-fail-inst",
Image: "alpine",
InstanceType: "t2.micro",
}

rbacSvc.On("Authorize", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil).Once()
typeRepo.On("GetByID", mock.Anything, "t2.micro").Return(&domain.InstanceType{
ID: "t2.micro", VCPUs: 1, MemoryMB: 1024,
}, nil).Once()
tenantSvc.On("CheckQuota", mock.Anything, tenantID, "instances", 1).Return(nil).Once()
tenantSvc.On("CheckQuota", mock.Anything, tenantID, "vcpus", 1).Return(nil).Once()
tenantSvc.On("CheckQuota", mock.Anything, tenantID, "memory", 1).Return(nil).Once()
tenantSvc.On("IncrementUsage", mock.Anything, tenantID, "vcpus", 1).Return(nil).Once()
tenantSvc.On("IncrementUsage", mock.Anything, tenantID, "memory", 1).Return(nil).Once()

var createdID uuid.UUID
repo.On("Create", mock.Anything, mock.MatchedBy(func(i *domain.Instance) bool {
createdID = i.ID
return i.Name == params.Name && i.UserID == userID
})).Return(nil).Once()

taskQueue.On("Enqueue", mock.Anything, "provision_queue", mock.Anything).Return(errors.New("enqueue error")).Once()
tenantSvc.On("DecrementUsage", mock.Anything, tenantID, "vcpus", 1).Return(nil).Once()
tenantSvc.On("DecrementUsage", mock.Anything, tenantID, "memory", 1).Return(nil).Once()
repo.On("Delete", mock.Anything, mock.MatchedBy(func(id uuid.UUID) bool {
return id == createdID
})).Return(nil).Once()

_, err := svc.LaunchInstance(ctx, params)
require.Error(t, err)
assert.Contains(t, err.Error(), "failed to enqueue provisioning task")

repo.AssertExpectations(t)
tenantSvc.AssertExpectations(t)
taskQueue.AssertExpectations(t)
})

t.Run("EnqueueAndRollbackFailure", func(t *testing.T) {
params := ports.LaunchParams{
Name: "enqueue-fail-inst-rollback-fail",
Image: "alpine",
InstanceType: "t2.micro",
}

rbacSvc.On("Authorize", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil).Once()
typeRepo.On("GetByID", mock.Anything, "t2.micro").Return(&domain.InstanceType{
ID: "t2.micro", VCPUs: 1, MemoryMB: 1024,
}, nil).Once()
tenantSvc.On("CheckQuota", mock.Anything, tenantID, "instances", 1).Return(nil).Once()
tenantSvc.On("CheckQuota", mock.Anything, tenantID, "vcpus", 1).Return(nil).Once()
tenantSvc.On("CheckQuota", mock.Anything, tenantID, "memory", 1).Return(nil).Once()
tenantSvc.On("IncrementUsage", mock.Anything, tenantID, "vcpus", 1).Return(nil).Once()
tenantSvc.On("IncrementUsage", mock.Anything, tenantID, "memory", 1).Return(nil).Once()

var createdID uuid.UUID
repo.On("Create", mock.Anything, mock.MatchedBy(func(i *domain.Instance) bool {
createdID = i.ID
return i.Name == params.Name && i.UserID == userID
})).Return(nil).Once()

taskQueue.On("Enqueue", mock.Anything, "provision_queue", mock.Anything).Return(errors.New("enqueue error")).Once()
tenantSvc.On("DecrementUsage", mock.Anything, tenantID, "vcpus", 1).Return(nil).Once()
tenantSvc.On("DecrementUsage", mock.Anything, tenantID, "memory", 1).Return(nil).Once()
repo.On("Delete", mock.Anything, mock.MatchedBy(func(id uuid.UUID) bool {
return id == createdID
})).Return(errors.New("delete error")).Once()

_, err := svc.LaunchInstance(ctx, params)
require.Error(t, err)
assert.Contains(t, err.Error(), "failed to enqueue provisioning task")
assert.NotContains(t, err.Error(), "delete error")

repo.AssertExpectations(t)
tenantSvc.AssertExpectations(t)
taskQueue.AssertExpectations(t)
})

t.Run("LaunchInstanceWithOptions_Success", func(t *testing.T) {
opts := ports.CreateInstanceOptions{
Name: "opts-success",
ImageName: "alpine",
Ports: []string{"80:80"},
}

rbacSvc.On("Authorize", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil).Once()
repo.On("Create", mock.Anything, mock.MatchedBy(func(i *domain.Instance) bool {
return i.Name == opts.Name
})).Return(nil).Once()
taskQueue.On("Enqueue", mock.Anything, "provision_queue", mock.Anything).Return(nil).Once()

inst, err := svc.LaunchInstanceWithOptions(ctx, opts)
require.NoError(t, err)
assert.NotNil(t, inst)
assert.Equal(t, opts.Name, inst.Name)

repo.AssertExpectations(t)
taskQueue.AssertExpectations(t)
})

t.Run("LaunchInstanceWithOptions_EnqueueFailure", func(t *testing.T) {
opts := ports.CreateInstanceOptions{
Name: "opts-fail",
ImageName: "alpine",
Ports: []string{"80:80"},
}

rbacSvc.On("Authorize", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil).Once()
var createdID uuid.UUID
repo.On("Create", mock.Anything, mock.MatchedBy(func(i *domain.Instance) bool {
createdID = i.ID
return i.Name == opts.Name
})).Return(nil).Once()
taskQueue.On("Enqueue", mock.Anything, "provision_queue", mock.Anything).Return(errors.New("enqueue error")).Once()
repo.On("Delete", mock.Anything, mock.MatchedBy(func(id uuid.UUID) bool {
return id == createdID
})).Return(nil).Once()

_, err := svc.LaunchInstanceWithOptions(ctx, opts)
require.Error(t, err)
assert.Contains(t, err.Error(), "failed to enqueue provisioning task")

repo.AssertExpectations(t)
taskQueue.AssertExpectations(t)
})

t.Run("LaunchInstanceWithOptions_EnqueueAndRollbackFailure", func(t *testing.T) {
opts := ports.CreateInstanceOptions{
Name: "opts-fail-rollback-fail",
ImageName: "alpine",
Ports: []string{"80:80"},
}

rbacSvc.On("Authorize", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil).Once()
var createdID uuid.UUID
repo.On("Create", mock.Anything, mock.MatchedBy(func(i *domain.Instance) bool {
createdID = i.ID
return i.Name == opts.Name
})).Return(nil).Once()
taskQueue.On("Enqueue", mock.Anything, "provision_queue", mock.Anything).Return(errors.New("enqueue error")).Once()
repo.On("Delete", mock.Anything, mock.MatchedBy(func(id uuid.UUID) bool {
return id == createdID
})).Return(errors.New("delete error")).Once()

_, err := svc.LaunchInstanceWithOptions(ctx, opts)
require.Error(t, err)
assert.Contains(t, err.Error(), "failed to enqueue provisioning task")
assert.NotContains(t, err.Error(), "delete error")

repo.AssertExpectations(t)
taskQueue.AssertExpectations(t)
})
}

func testInstanceServiceLifecycleUnit(t *testing.T) {
Expand Down
Loading