Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions internal/core/services/instance.go
Original file line number Diff line number Diff line change
Expand Up @@ -246,6 +246,10 @@ func (s *InstanceService) LaunchInstance(ctx context.Context, params ports.Launc
// Rollback quota reservation on enqueue failure
_ = s.tenantSvc.DecrementUsage(ctx, tenantID, "vcpus", it.VCPUs)
_ = s.tenantSvc.DecrementUsage(ctx, tenantID, "memory", it.MemoryMB/1024)
// Rollback database record creation
if delErr := s.repo.Delete(ctx, inst.ID); delErr != nil {
s.logger.Error("failed to delete instance record after enqueue failure", "instance_id", inst.ID, "error", delErr)
}
return nil, errors.Wrap(errors.Internal, "failed to enqueue provisioning task", err)
}

Expand Down Expand Up @@ -313,6 +317,10 @@ func (s *InstanceService) LaunchInstanceWithOptions(ctx context.Context, opts po

if err := s.taskQueue.Enqueue(ctx, "provision_queue", job); err != nil {
s.logger.Error("failed to enqueue provision job", "instance_id", inst.ID, "error", err)
// Rollback database record creation
if delErr := s.repo.Delete(ctx, inst.ID); delErr != nil {
s.logger.Error("failed to delete instance record after enqueue failure", "instance_id", inst.ID, "error", delErr)
}
return nil, errors.Wrap(errors.Internal, "failed to enqueue provisioning task", err)
}

Expand Down
109 changes: 105 additions & 4 deletions internal/core/services/instance_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,13 +47,17 @@ func (r *FaultyInstanceRepository) Create(ctx context.Context, instance *domain.
}

type InMemoryTaskQueue struct {
jobs []string
mu sync.Mutex
jobs []string
mu sync.Mutex
ShouldFail bool
}

func (q *InMemoryTaskQueue) Enqueue(ctx context.Context, queueName string, payload interface{}) error {
q.mu.Lock()
defer q.mu.Unlock()
if q.ShouldFail {
return fmt.Errorf("simulated enqueue failure")
}
q.jobs = append(q.jobs, fmt.Sprintf("%v", payload))
return nil
}
Expand Down Expand Up @@ -334,7 +338,7 @@ func TestInstanceServiceLaunchDBFailure(t *testing.T) {
assert.Contains(t, err.Error(), "simulated database failure")

// Verify no junk in DB (using real repo to check)
list, err := realRepo.List(ctx)
list, err := realRepo.List(ctx, nil)
require.NoError(t, err)
assert.Empty(t, list)
}
Expand Down Expand Up @@ -413,7 +417,7 @@ func TestInstanceServiceLaunchConcurrency(t *testing.T) {
}

// Verify all created
list, err := repo.List(ctx)
list, err := repo.List(ctx, nil)
require.NoError(t, err)
assert.Len(t, list, concurrency)

Expand Down Expand Up @@ -709,3 +713,100 @@ func TestLaunchInstanceWithOptions(t *testing.T) {
_ = compute.DeleteInstance(ctx, inst.ContainerID)
}
}

func TestInstanceServiceLaunchEnqueueFailure(t *testing.T) {
db := setupDB(t)
ctx := setupTestUser(t, db)

repo := postgres.NewInstanceRepository(db)
vpcRepo := postgres.NewVpcRepository(db)
subnetRepo := postgres.NewSubnetRepository(db)
volumeRepo := postgres.NewVolumeRepository(db)
itRepo := postgres.NewInstanceTypeRepository(db)

compute := noop.NewNoopComputeBackend()

defaultType := &domain.InstanceType{ID: testInstanceType, Name: "Basic 2", VCPUs: 1, MemoryMB: 128, DiskGB: 1}
_, _ = itRepo.Create(ctx, defaultType)
Comment thread
coderabbitai[bot] marked this conversation as resolved.
Outdated

rbacSvc := new(MockRBACService)
rbacSvc.On("Authorize", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil)

eventSvc := services.NewEventService(services.EventServiceParams{
Repo: postgres.NewEventRepository(db),
RBACSvc: rbacSvc,
Publisher: nil,
Logger: slog.Default(),
})
auditSvc := services.NewAuditService(services.AuditServiceParams{
Repo: postgres.NewAuditRepository(db),
RBACSvc: rbacSvc,
})

// Create a task queue that fails
taskQueue := &InMemoryTaskQueue{ShouldFail: true}

sshKeySvc, err := services.NewSSHKeyService(services.SSHKeyServiceParams{
Repo: postgres.NewSSHKeyRepo(db),
RBACSvc: rbacSvc,
})
require.NoError(t, err)

tenantSvc := services.NewTenantService(services.TenantServiceParams{
Repo: postgres.NewTenantRepo(db),
UserRepo: postgres.NewUserRepo(db),
RBACSvc: rbacSvc,
Logger: slog.Default(),
})

svc := services.NewInstanceService(services.InstanceServiceParams{
Repo: repo,
VpcRepo: vpcRepo,
SubnetRepo: subnetRepo,
VolumeRepo: volumeRepo,
InstanceTypeRepo: itRepo,
RBAC: rbacSvc,
Compute: compute,
EventSvc: eventSvc,
AuditSvc: auditSvc,
TaskQueue: taskQueue,
Logger: slog.Default(),
TenantSvc: tenantSvc,
SSHKeySvc: sshKeySvc,
})

// Attempt LaunchInstance
name := "enqueue-fail-integration"
_, err = svc.LaunchInstance(ctx, coreports.LaunchParams{
Name: name,
Image: testImage,
InstanceType: testInstanceType,
})

// Verify Failure
require.Error(t, err)
assert.Contains(t, err.Error(), "failed to enqueue provisioning task")

// Verify instance not in DB
inst, err := repo.GetByName(ctx, name)
require.Error(t, err)
assert.Nil(t, inst)

// Attempt LaunchInstanceWithOptions
optsName := "enqueue-fail-opts-integration"
opts := coreports.CreateInstanceOptions{
Name: optsName,
ImageName: testImage,
}
_, err = svc.LaunchInstanceWithOptions(ctx, opts)

// Verify Failure
require.Error(t, err)
assert.Contains(t, err.Error(), "failed to enqueue provisioning task")

// Verify instance not in DB
instOpts, err := repo.GetByName(ctx, optsName)
require.Error(t, err)
assert.Nil(t, instOpts)
}

87 changes: 87 additions & 0 deletions internal/core/services/instance_unit_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,93 @@ func testInstanceServiceLaunchInstanceUnit(t *testing.T) {
require.Error(t, err)
assert.Contains(t, err.Error(), "quota exceeded")
})

t.Run("EnqueueFailure", func(t *testing.T) {
params := ports.LaunchParams{
Name: "enqueue-fail-inst",
Image: "alpine",
InstanceType: "t2.micro",
}

rbacSvc.On("Authorize", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil).Once()
typeRepo.On("GetByID", mock.Anything, "t2.micro").Return(&domain.InstanceType{
ID: "t2.micro", VCPUs: 1, MemoryMB: 1024,
}, nil).Once()
tenantSvc.On("CheckQuota", mock.Anything, tenantID, "instances", 1).Return(nil).Once()
tenantSvc.On("CheckQuota", mock.Anything, tenantID, "vcpus", 1).Return(nil).Once()
tenantSvc.On("CheckQuota", mock.Anything, tenantID, "memory", 1).Return(nil).Once()
tenantSvc.On("IncrementUsage", mock.Anything, tenantID, "vcpus", 1).Return(nil).Once()
tenantSvc.On("IncrementUsage", mock.Anything, tenantID, "memory", 1).Return(nil).Once()

var createdID uuid.UUID
repo.On("Create", mock.Anything, mock.MatchedBy(func(i *domain.Instance) bool {
createdID = i.ID
return i.Name == params.Name && i.UserID == userID
})).Return(nil).Once()

taskQueue.On("Enqueue", mock.Anything, "provision_queue", mock.Anything).Return(errors.New("enqueue error")).Once()
tenantSvc.On("DecrementUsage", mock.Anything, tenantID, "vcpus", 1).Return(nil).Once()
tenantSvc.On("DecrementUsage", mock.Anything, tenantID, "memory", 1).Return(nil).Once()
repo.On("Delete", mock.Anything, mock.MatchedBy(func(id uuid.UUID) bool {
return id == createdID
})).Return(nil).Once()

_, err := svc.LaunchInstance(ctx, params)
require.Error(t, err)
assert.Contains(t, err.Error(), "failed to enqueue provisioning task")

repo.AssertExpectations(t)
tenantSvc.AssertExpectations(t)
taskQueue.AssertExpectations(t)
})

t.Run("LaunchInstanceWithOptions_Success", func(t *testing.T) {
opts := ports.CreateInstanceOptions{
Name: "opts-success",
ImageName: "alpine",
Ports: []string{"80:80"},
}

rbacSvc.On("Authorize", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil).Once()
repo.On("Create", mock.Anything, mock.MatchedBy(func(i *domain.Instance) bool {
return i.Name == opts.Name
})).Return(nil).Once()
taskQueue.On("Enqueue", mock.Anything, "provision_queue", mock.Anything).Return(nil).Once()

inst, err := svc.LaunchInstanceWithOptions(ctx, opts)
require.NoError(t, err)
assert.NotNil(t, inst)
assert.Equal(t, opts.Name, inst.Name)

repo.AssertExpectations(t)
taskQueue.AssertExpectations(t)
})

t.Run("LaunchInstanceWithOptions_EnqueueFailure", func(t *testing.T) {
opts := ports.CreateInstanceOptions{
Name: "opts-fail",
ImageName: "alpine",
Ports: []string{"80:80"},
}

rbacSvc.On("Authorize", mock.Anything, mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil).Once()
var createdID uuid.UUID
repo.On("Create", mock.Anything, mock.MatchedBy(func(i *domain.Instance) bool {
createdID = i.ID
return i.Name == opts.Name
})).Return(nil).Once()
taskQueue.On("Enqueue", mock.Anything, "provision_queue", mock.Anything).Return(errors.New("enqueue error")).Once()
repo.On("Delete", mock.Anything, mock.MatchedBy(func(id uuid.UUID) bool {
return id == createdID
})).Return(nil).Once()

_, err := svc.LaunchInstanceWithOptions(ctx, opts)
require.Error(t, err)
assert.Contains(t, err.Error(), "failed to enqueue provisioning task")

repo.AssertExpectations(t)
taskQueue.AssertExpectations(t)
})
}

func testInstanceServiceLifecycleUnit(t *testing.T) {
Expand Down
Loading