Skip to content
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 47 additions & 15 deletions pkg/apiclients/testapi/testapi.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,13 +171,27 @@ var (
)

// StartTestParams defines parameters for the high-level StartTest function.
type StartTestParams struct {
type StartTestParams interface {
isStartTestParams()
}

type StartTestSubjectParams struct {
OrgID string
Subject TestSubjectCreate
LocalPolicy *LocalPolicy
}

func (p StartTestSubjectParams) isStartTestParams() {}

type StartTestResourcesParams struct {
OrgID string
Subject *TestSubjectCreate
Resources *[]TestResourceCreateItem
Resources []TestResourceCreateItem
ScanConfig ScanConfiguration
LocalPolicy *LocalPolicy
}

func (p StartTestResourcesParams) isStartTestParams() {}

// testResult is the concrete implementation of the TestResult interface for
// accessing summary and findings data of a completed test.
type testResult struct {
Expand Down Expand Up @@ -300,38 +314,56 @@ func NewTestClient(serverBaseUrl string, options ...ConfigOption) (TestClient, e

// Create the initial test and return a handle to poll it
func (c *client) StartTest(ctx context.Context, params StartTestParams) (TestHandle, error) {
if params.Resources != nil {
if len(*params.Resources) > 0 {
for i, resource := range *params.Resources {
var orgID string
var subject *TestSubjectCreate
var resources *[]TestResourceCreateItem
var localPolicy *LocalPolicy
var scanConfig *ScanConfiguration

switch p := params.(type) {
case StartTestResourcesParams:
orgID = p.OrgID
resources = &p.Resources
localPolicy = p.LocalPolicy
scanConfig = &p.ScanConfig

if len(*resources) > 0 {
for i, resource := range *resources {
if len(resource.union) == 0 {
return nil, fmt.Errorf("resource at index %d is required in StartTestParams and must be populated", i)
}
}
} else {
return nil, fmt.Errorf("resources do not contain any items in StartTestParams")
}
} else if params.Subject != nil {
if len(params.Subject.union) == 0 {
case StartTestSubjectParams:
orgID = p.OrgID
subject = &p.Subject
localPolicy = p.LocalPolicy

if len(subject.union) == 0 {
return nil, fmt.Errorf("subject is required in StartTestParams and must be populated")
}
} else {
return nil, fmt.Errorf("either resources or subject are required in StartTestParams and must be populated")

default:
return nil, fmt.Errorf("unsupported StartTestParams type: %T", params)
}

if params.OrgID == "" {
if orgID == "" {
return nil, fmt.Errorf("OrgID is required")
}
orgUUID, err := uuid.Parse(params.OrgID)
orgUUID, err := uuid.Parse(orgID)
if err != nil {
return nil, fmt.Errorf("invalid OrgID format: %w", err)
}

// Create test body
testAttributes := TestAttributesCreate{Subject: params.Subject, Resources: params.Resources}
testAttributes := TestAttributesCreate{Subject: subject, Resources: resources}

if params.LocalPolicy != nil {
if localPolicy != nil || scanConfig != nil {
testAttributes.Config = &TestConfiguration{
LocalPolicy: params.LocalPolicy,
LocalPolicy: localPolicy,
ScanConfig: scanConfig,
}
}

Expand Down