refactor: use ormer from the ctx for scanner ctl mgr and dao (#14313)

Signed-off-by: He Weiwei <hweiwei@vmware.com>
This commit is contained in:
He Weiwei 2021-03-01 12:02:40 +08:00 committed by GitHub
parent 2fcdf3e09b
commit 9161a3fbdf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
21 changed files with 412 additions and 500 deletions

View File

@ -56,8 +56,8 @@ type basicController struct {
}
// ListRegistrations ...
func (bc *basicController) ListRegistrations(query *q.Query) ([]*scanner.Registration, error) {
l, err := bc.manager.List(query)
func (bc *basicController) ListRegistrations(ctx context.Context, query *q.Query) ([]*scanner.Registration, error) {
l, err := bc.manager.List(ctx, query)
if err != nil {
return nil, errors.Wrap(err, "api controller: list registrations")
}
@ -66,18 +66,18 @@ func (bc *basicController) ListRegistrations(query *q.Query) ([]*scanner.Registr
}
// CreateRegistration ...
func (bc *basicController) CreateRegistration(registration *scanner.Registration) (string, error) {
func (bc *basicController) CreateRegistration(ctx context.Context, registration *scanner.Registration) (string, error) {
if isReservedName(registration.Name) {
return "", errors.BadRequestError(nil).WithMessage(`name "%s" is reserved, please try a different name`, registration.Name)
}
// Check if the registration is available
if _, err := bc.Ping(registration); err != nil {
if _, err := bc.Ping(ctx, registration); err != nil {
return "", errors.Wrap(err, "api controller: create registration")
}
// Check if there are any registrations already existing.
l, err := bc.manager.List(&q.Query{
l, err := bc.manager.List(ctx, &q.Query{
PageSize: 1,
PageNumber: 1,
})
@ -90,12 +90,12 @@ func (bc *basicController) CreateRegistration(registration *scanner.Registration
registration.IsDefault = true
}
return bc.manager.Create(registration)
return bc.manager.Create(ctx, registration)
}
// GetRegistration ...
func (bc *basicController) GetRegistration(registrationUUID string) (*scanner.Registration, error) {
r, err := bc.manager.Get(registrationUUID)
func (bc *basicController) GetRegistration(ctx context.Context, registrationUUID string) (*scanner.Registration, error) {
r, err := bc.manager.Get(ctx, registrationUUID)
if err != nil {
return nil, errors.Wrap(err, "api controller: get registration")
}
@ -104,8 +104,8 @@ func (bc *basicController) GetRegistration(registrationUUID string) (*scanner.Re
}
// RegistrationExists ...
func (bc *basicController) RegistrationExists(registrationUUID string) bool {
registration, err := bc.manager.Get(registrationUUID)
func (bc *basicController) RegistrationExists(ctx context.Context, registrationUUID string) bool {
registration, err := bc.manager.Get(ctx, registrationUUID)
// Just logged when an error occurred
if err != nil {
@ -116,7 +116,7 @@ func (bc *basicController) RegistrationExists(registrationUUID string) bool {
}
// UpdateRegistration ...
func (bc *basicController) UpdateRegistration(registration *scanner.Registration) error {
func (bc *basicController) UpdateRegistration(ctx context.Context, registration *scanner.Registration) error {
if registration.IsDefault && registration.Disabled {
return errors.Errorf("default registration %s can not be marked to disabled", registration.UUID)
}
@ -125,12 +125,12 @@ func (bc *basicController) UpdateRegistration(registration *scanner.Registration
return errors.BadRequestError(nil).WithMessage(`name "%s" is reserved, please try a different name`, registration.Name)
}
return bc.manager.Update(registration)
return bc.manager.Update(ctx, registration)
}
// SetDefaultRegistration ...
func (bc *basicController) DeleteRegistration(registrationUUID string) (*scanner.Registration, error) {
registration, err := bc.manager.Get(registrationUUID)
func (bc *basicController) DeleteRegistration(ctx context.Context, registrationUUID string) (*scanner.Registration, error) {
registration, err := bc.manager.Get(ctx, registrationUUID)
if err != nil {
return nil, errors.Wrap(err, "api controller: delete registration")
}
@ -140,7 +140,7 @@ func (bc *basicController) DeleteRegistration(registrationUUID string) (*scanner
return nil, nil
}
if err := bc.manager.Delete(registrationUUID); err != nil {
if err := bc.manager.Delete(ctx, registrationUUID); err != nil {
return nil, errors.Wrap(err, "api controller: delete registration")
}
@ -148,8 +148,8 @@ func (bc *basicController) DeleteRegistration(registrationUUID string) (*scanner
}
// SetDefaultRegistration ...
func (bc *basicController) SetDefaultRegistration(registrationUUID string) error {
return bc.manager.SetAsDefault(registrationUUID)
func (bc *basicController) SetDefaultRegistration(ctx context.Context, registrationUUID string) error {
return bc.manager.SetAsDefault(ctx, registrationUUID)
}
// SetRegistrationByProject ...
@ -204,7 +204,7 @@ func (bc *basicController) GetRegistrationByProject(ctx context.Context, project
var registration *scanner.Registration
if len(m) > 0 {
if registrationID, ok := m[proScannerMetaKey]; ok && len(registrationID) > 0 {
registration, err = bc.manager.Get(registrationID)
registration, err = bc.manager.Get(ctx, registrationID)
if err != nil {
return nil, errors.Wrap(err, "api controller: get project scanner")
}
@ -221,7 +221,7 @@ func (bc *basicController) GetRegistrationByProject(ctx context.Context, project
if registration == nil {
// Second, get the default one
registration, err = bc.manager.GetDefault()
registration, err = bc.manager.GetDefault(ctx)
if err != nil {
return nil, errors.Wrap(err, "api controller: get project scanner")
}
@ -236,7 +236,7 @@ func (bc *basicController) GetRegistrationByProject(ctx context.Context, project
if opts.Ping {
// Get metadata of the configured registration
meta, err := bc.Ping(registration)
meta, err := bc.Ping(ctx, registration)
if err != nil {
// Not blocked, just logged it
log.Error(errors.Wrap(err, "api controller: get project scanner"))
@ -256,7 +256,7 @@ func (bc *basicController) GetRegistrationByProject(ctx context.Context, project
}
// Ping ...
func (bc *basicController) Ping(registration *scanner.Registration) (*v1.ScannerAdapterMetadata, error) {
func (bc *basicController) Ping(ctx context.Context, registration *scanner.Registration) (*v1.ScannerAdapterMetadata, error) {
if registration == nil {
return nil, errors.New("nil registration to ping")
}
@ -314,17 +314,17 @@ func (bc *basicController) Ping(registration *scanner.Registration) (*v1.Scanner
}
// GetMetadata ...
func (bc *basicController) GetMetadata(registrationUUID string) (*v1.ScannerAdapterMetadata, error) {
func (bc *basicController) GetMetadata(ctx context.Context, registrationUUID string) (*v1.ScannerAdapterMetadata, error) {
if len(registrationUUID) == 0 {
return nil, errors.New("empty registration uuid")
}
r, err := bc.manager.Get(registrationUUID)
r, err := bc.manager.Get(ctx, registrationUUID)
if err != nil {
return nil, errors.Wrap(err, "scanner controller: get metadata")
}
return bc.Ping(r)
return bc.Ping(ctx, r)
}
var (

View File

@ -27,6 +27,7 @@ import (
v1testing "github.com/goharbor/harbor/src/testing/pkg/scan/rest/v1"
scannertesting "github.com/goharbor/harbor/src/testing/pkg/scan/scanner"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
)
@ -107,18 +108,18 @@ func (suite *ControllerTestSuite) TestListRegistrations() {
suite.sample.UUID = "uuid"
l := []*scanner.Registration{suite.sample}
suite.mMgr.On("List", query).Return(l, nil)
suite.mMgr.On("List", mock.Anything, query).Return(l, nil)
rl, err := suite.c.ListRegistrations(query)
rl, err := suite.c.ListRegistrations(context.TODO(), query)
require.NoError(suite.T(), err)
assert.Equal(suite.T(), 1, len(rl))
}
// TestCreateRegistration tests CreateRegistration
func (suite *ControllerTestSuite) TestCreateRegistration() {
suite.mMgr.On("Create", suite.sample).Return("uuid", nil)
suite.mMgr.On("Create", mock.Anything, suite.sample).Return("uuid", nil)
uid, err := suite.mMgr.Create(suite.sample)
uid, err := suite.mMgr.Create(context.TODO(), suite.sample)
require.NoError(suite.T(), err)
assert.Equal(suite.T(), uid, "uuid")
@ -127,9 +128,9 @@ func (suite *ControllerTestSuite) TestCreateRegistration() {
// TestGetRegistration tests GetRegistration
func (suite *ControllerTestSuite) TestGetRegistration() {
suite.sample.UUID = "uuid"
suite.mMgr.On("Get", "uuid").Return(suite.sample, nil)
suite.mMgr.On("Get", mock.Anything, "uuid").Return(suite.sample, nil)
rr, err := suite.c.GetRegistration("uuid")
rr, err := suite.c.GetRegistration(context.TODO(), "uuid")
require.NoError(suite.T(), err)
assert.NotNil(suite.T(), rr)
assert.Equal(suite.T(), "forUT", rr.Name)
@ -138,33 +139,33 @@ func (suite *ControllerTestSuite) TestGetRegistration() {
// TestRegistrationExists tests RegistrationExists
func (suite *ControllerTestSuite) TestRegistrationExists() {
suite.sample.UUID = "uuid"
suite.mMgr.On("Get", "uuid").Return(suite.sample, nil)
suite.mMgr.On("Get", mock.Anything, "uuid").Return(suite.sample, nil)
exists := suite.c.RegistrationExists("uuid")
exists := suite.c.RegistrationExists(context.TODO(), "uuid")
assert.Equal(suite.T(), true, exists)
suite.mMgr.On("Get", "uuid2").Return(nil, nil)
suite.mMgr.On("Get", mock.Anything, "uuid2").Return(nil, nil)
exists = suite.c.RegistrationExists("uuid2")
exists = suite.c.RegistrationExists(context.TODO(), "uuid2")
assert.Equal(suite.T(), false, exists)
}
// TestUpdateRegistration tests UpdateRegistration
func (suite *ControllerTestSuite) TestUpdateRegistration() {
suite.sample.UUID = "uuid"
suite.mMgr.On("Update", suite.sample).Return(nil)
suite.mMgr.On("Update", mock.Anything, suite.sample).Return(nil)
err := suite.c.UpdateRegistration(suite.sample)
err := suite.c.UpdateRegistration(context.TODO(), suite.sample)
require.NoError(suite.T(), err)
}
// TestDeleteRegistration tests DeleteRegistration
func (suite *ControllerTestSuite) TestDeleteRegistration() {
suite.sample.UUID = "uuid"
suite.mMgr.On("Get", "uuid").Return(suite.sample, nil)
suite.mMgr.On("Delete", "uuid").Return(nil)
suite.mMgr.On("Get", mock.Anything, "uuid").Return(suite.sample, nil)
suite.mMgr.On("Delete", mock.Anything, "uuid").Return(nil)
r, err := suite.c.DeleteRegistration("uuid")
r, err := suite.c.DeleteRegistration(context.TODO(), "uuid")
require.NoError(suite.T(), err)
require.NotNil(suite.T(), r)
assert.Equal(suite.T(), "forUT", r.Name)
@ -172,9 +173,9 @@ func (suite *ControllerTestSuite) TestDeleteRegistration() {
// TestSetDefaultRegistration tests SetDefaultRegistration
func (suite *ControllerTestSuite) TestSetDefaultRegistration() {
suite.mMgr.On("SetAsDefault", "uuid").Return(nil)
suite.mMgr.On("SetAsDefault", mock.Anything, "uuid").Return(nil)
err := suite.c.SetDefaultRegistration("uuid")
err := suite.c.SetDefaultRegistration(context.TODO(), "uuid")
require.NoError(suite.T(), err)
}
@ -189,15 +190,15 @@ func (suite *ControllerTestSuite) TestSetRegistrationByProject() {
var pid, pid2 int64 = 1, 2
// not set before
suite.mMeta.On("Get", context.TODO(), pid, proScannerMetaKey).Return(m, nil)
suite.mMeta.On("Add", context.TODO(), pid, mm).Return(nil)
suite.mMeta.On("Get", mock.Anything, pid, proScannerMetaKey).Return(m, nil)
suite.mMeta.On("Add", mock.Anything, pid, mm).Return(nil)
err := suite.c.SetRegistrationByProject(context.TODO(), pid, "uuid")
require.NoError(suite.T(), err)
// Set before
suite.mMeta.On("Get", context.TODO(), pid2, proScannerMetaKey).Return(mm, nil)
suite.mMeta.On("Update", context.TODO(), pid2, mmm).Return(nil)
suite.mMeta.On("Get", mock.Anything, pid2, proScannerMetaKey).Return(mm, nil)
suite.mMeta.On("Update", mock.Anything, pid2, mmm).Return(nil)
err = suite.c.SetRegistrationByProject(context.TODO(), pid2, "uuid2")
require.NoError(suite.T(), err)
@ -212,16 +213,16 @@ func (suite *ControllerTestSuite) TestGetRegistrationByProject() {
var pid int64 = 1
suite.sample.UUID = "uuid"
suite.mMeta.On("Get", context.TODO(), pid, proScannerMetaKey).Return(m, nil)
suite.mMgr.On("Get", "uuid").Return(suite.sample, nil)
suite.mMeta.On("Get", mock.Anything, pid, proScannerMetaKey).Return(m, nil)
suite.mMgr.On("Get", mock.Anything, "uuid").Return(suite.sample, nil)
r, err := suite.c.GetRegistrationByProject(context.TODO(), pid)
require.NoError(suite.T(), err)
require.Equal(suite.T(), "forUT", r.Name)
// Not configured at project level, return system default
suite.mMeta.On("Get", context.TODO(), pid, proScannerMetaKey).Return(nil, nil)
suite.mMgr.On("GetDefault").Return(suite.sample, nil)
suite.mMeta.On("Get", mock.Anything, pid, proScannerMetaKey).Return(nil, nil)
suite.mMgr.On("GetDefault", mock.Anything).Return(suite.sample, nil)
r, err = suite.c.GetRegistrationByProject(context.TODO(), pid)
require.NoError(suite.T(), err)
@ -238,8 +239,8 @@ func (suite *ControllerTestSuite) TestGetRegistrationByProjectWhenPingError() {
var pid int64 = 1
suite.sample.UUID = "uuid"
suite.mMeta.On("Get", context.TODO(), pid, proScannerMetaKey).Return(m, nil)
suite.mMgr.On("Get", "uuid").Return(suite.sample, nil)
suite.mMeta.On("Get", mock.Anything, pid, proScannerMetaKey).Return(m, nil)
suite.mMgr.On("Get", mock.Anything, "uuid").Return(suite.sample, nil)
// Ping error
mc := &v1testing.Client{}
@ -256,7 +257,7 @@ func (suite *ControllerTestSuite) TestGetRegistrationByProjectWhenPingError() {
// TestPing ...
func (suite *ControllerTestSuite) TestPing() {
meta, err := suite.c.Ping(suite.sample)
meta, err := suite.c.Ping(context.TODO(), suite.sample)
require.NoError(suite.T(), err)
suite.NotNil(meta)
}
@ -293,7 +294,7 @@ func (suite *ControllerTestSuite) TestPingWithGenericMimeType() {
proMetaMgr: suite.mMeta,
clientPool: mcp,
}
meta, err := suite.c.Ping(suite.sample)
meta, err := suite.c.Ping(context.TODO(), suite.sample)
require.NoError(suite.T(), err)
suite.NotNil(meta)
}
@ -301,9 +302,9 @@ func (suite *ControllerTestSuite) TestPingWithGenericMimeType() {
// TestGetMetadata ...
func (suite *ControllerTestSuite) TestGetMetadata() {
suite.sample.UUID = "uuid"
suite.mMgr.On("Get", "uuid").Return(suite.sample, nil)
suite.mMgr.On("Get", mock.Anything, "uuid").Return(suite.sample, nil)
meta, err := suite.c.GetMetadata(suite.sample.UUID)
meta, err := suite.c.GetMetadata(context.TODO(), suite.sample.UUID)
require.NoError(suite.T(), err)
suite.NotNil(meta)
suite.Equal(1, len(meta.Capabilities))

View File

@ -29,71 +29,78 @@ type Controller interface {
// Query parameters are optional
//
// Arguments:
// ctx context.Context : the context for this method
// query *q.Query : query parameters
//
// Returns:
// []*scanner.Registration : scanner list of all the matched ones
// error : non nil error if any errors occurred
ListRegistrations(query *q.Query) ([]*scanner.Registration, error)
ListRegistrations(ctx context.Context, query *q.Query) ([]*scanner.Registration, error)
// CreateRegistration creates a new scanner registration with the given data.
// Returns the scanner registration identifier.
//
// Arguments:
// ctx context.Context : the context for this method
// registration *scanner.Registration : scanner registration to create
//
// Returns:
// string : the generated UUID of the new scanner
// error : non nil error if any errors occurred
CreateRegistration(registration *scanner.Registration) (string, error)
CreateRegistration(ctx context.Context, registration *scanner.Registration) (string, error)
// GetRegistration returns the details of the specified scanner registration.
//
// Arguments:
// ctx context.Context : the context for this method
// registrationUUID string : the UUID of the given scanner
//
// Returns:
// *scanner.Registration : the required scanner
// error : non nil error if any errors occurred
GetRegistration(registrationUUID string) (*scanner.Registration, error)
GetRegistration(ctx context.Context, registrationUUID string) (*scanner.Registration, error)
// RegistrationExists checks if the provided registration is there.
//
// Arguments:
// ctx context.Context : the context for this method
// registrationUUID string : the UUID of the given scanner
//
// Returns:
// true for existing or false for not existing
RegistrationExists(registrationUUID string) bool
RegistrationExists(ctx context.Context, registrationUUID string) bool
// UpdateRegistration updates the specified scanner registration.
//
// Arguments:
// ctx context.Context : the context for this method
// registration *scanner.Registration : scanner registration to update
//
// Returns:
// error : non nil error if any errors occurred
UpdateRegistration(registration *scanner.Registration) error
UpdateRegistration(ctx context.Context, registration *scanner.Registration) error
// DeleteRegistration deletes the specified scanner registration.
//
// Arguments:
// ctx context.Context : the context for this method
// registrationUUID string : the UUID of the given scanner which is going to be deleted
//
// Returns:
// *scanner.Registration : the deleted scanner
// error : non nil error if any errors occurred
DeleteRegistration(registrationUUID string) (*scanner.Registration, error)
DeleteRegistration(ctx context.Context, registrationUUID string) (*scanner.Registration, error)
// SetDefaultRegistration marks the specified scanner registration as default.
// The implementation is supposed to unset any registration previously set as default.
//
// Arguments:
// ctx context.Context : the context for this method
// registrationUUID string : the UUID of the given scanner which is marked as default
//
// Returns:
// error : non nil error if any errors occurred
SetDefaultRegistration(registrationUUID string) error
SetDefaultRegistration(ctx context.Context, registrationUUID string) error
// SetRegistrationByProject sets scanner for the given project.
//
@ -123,20 +130,22 @@ type Controller interface {
// Returns `nil` if connection succeeded, a non `nil` error otherwise.
//
// Arguments:
// ctx context.Context : the context for this method
// registration *scanner.Registration : scanner registration to ping
//
// Returns:
// *v1.ScannerAdapterMetadata : metadata returned by the scanner if successfully ping
// error : non nil error if any errors occurred
Ping(registration *scanner.Registration) (*v1.ScannerAdapterMetadata, error)
Ping(ctx context.Context, registration *scanner.Registration) (*v1.ScannerAdapterMetadata, error)
// GetMetadata returns the metadata of the given scanner.
//
// Arguments:
// ctx context.Context : the context for this method
// registrationUUID string : the UUID of the given scanner which is marked as default
//
// Returns:
// *v1.ScannerAdapterMetadata : metadata returned by the scanner if successfully ping
// error : non nil error if any errors occurred
GetMetadata(registrationUUID string) (*v1.ScannerAdapterMetadata, error)
GetMetadata(ctx context.Context, registrationUUID string) (*v1.ScannerAdapterMetadata, error)
}

View File

@ -131,7 +131,7 @@ func (sa *ProjectScannerAPI) GetProScannerCandidates() {
PageNumber: p,
}
all, err := sa.c.ListRegistrations(query)
all, err := sa.c.ListRegistrations(sa.Context(), query)
if err != nil {
sa.SendInternalServerError(errors.Wrap(err, "scanner API: get project scanner candidates"))
return

View File

@ -111,7 +111,7 @@ func (suite *ProScannerAPITestSuite) TestScannerAPIGetScannerCandidates() {
Description: "JUST FOR TEST",
URL: "https://a.b.c",
}}
suite.mockC.On("ListRegistrations", query).Return(ll, nil)
suite.mockC.On("ListRegistrations", mock.Anything, query).Return(ll, nil)
// Get
l := make([]*scanner.Registration, 0)

View File

@ -16,10 +16,11 @@ package api
import (
"fmt"
"net/http"
"github.com/goharbor/harbor/src/common/rbac"
"github.com/goharbor/harbor/src/common/rbac/system"
"github.com/goharbor/harbor/src/pkg/permission/types"
"net/http"
s "github.com/goharbor/harbor/src/controller/scanner"
"github.com/goharbor/harbor/src/lib/errors"
@ -76,7 +77,7 @@ func (sa *ScannerAPI) Metadata() {
}
uuid := sa.GetStringFromPath(":uuid")
meta, err := sa.c.GetMetadata(uuid)
meta, err := sa.c.GetMetadata(sa.Context(), uuid)
if err != nil {
sa.SendInternalServerError(errors.Wrap(err, "scanner API: get metadata"))
return
@ -118,7 +119,7 @@ func (sa *ScannerAPI) List() {
query.Keywords = kws
}
all, err := sa.c.ListRegistrations(query)
all, err := sa.c.ListRegistrations(sa.Context(), query)
if err != nil {
sa.SendInternalServerError(errors.Wrap(err, "scanner API: list all"))
return
@ -156,7 +157,7 @@ func (sa *ScannerAPI) Create() {
// All newly created should be non default one except the 1st one
r.IsDefault = false
uuid, err := sa.c.CreateRegistration(r)
uuid, err := sa.c.CreateRegistration(sa.Context(), r)
if err != nil {
sa.SendError(errors.Wrap(err, "scanner API: create"))
return
@ -220,7 +221,7 @@ func (sa *ScannerAPI) Update() {
getChanges(r, rr)
if err := sa.c.UpdateRegistration(r); err != nil {
if err := sa.c.UpdateRegistration(sa.Context(), r); err != nil {
sa.SendInternalServerError(errors.Wrap(err, "scanner API: update"))
return
}
@ -251,7 +252,7 @@ func (sa *ScannerAPI) Delete() {
return
}
deleted, err := sa.c.DeleteRegistration(r.UUID)
deleted, err := sa.c.DeleteRegistration(sa.Context(), r.UUID)
if err != nil {
sa.SendInternalServerError(errors.Wrap(err, "scanner API: delete"))
return
@ -277,7 +278,7 @@ func (sa *ScannerAPI) SetAsDefault() {
if v, ok := m["is_default"]; ok {
if isDefault, y := v.(bool); y && isDefault {
if err := sa.c.SetDefaultRegistration(uid); err != nil {
if err := sa.c.SetDefaultRegistration(sa.Context(), uid); err != nil {
sa.SendInternalServerError(errors.Wrap(err, "scanner API: set as default"))
}
@ -307,7 +308,7 @@ func (sa *ScannerAPI) Ping() {
return
}
if _, err := sa.c.Ping(r); err != nil {
if _, err := sa.c.Ping(sa.Context(), r); err != nil {
sa.SendInternalServerError(errors.Wrap(err, "scanner API: ping"))
return
}
@ -317,7 +318,7 @@ func (sa *ScannerAPI) Ping() {
func (sa *ScannerAPI) get() *scanner.Registration {
uid := sa.GetStringFromPath(":uuid")
r, err := sa.c.GetRegistration(uid)
r, err := sa.c.GetRegistration(sa.Context(), uid)
if err != nil {
sa.SendInternalServerError(errors.Wrap(err, "scanner API: get"))
return nil
@ -341,7 +342,7 @@ func (sa *ScannerAPI) checkDuplicated(property, value string) bool {
Keywords: kw,
}
l, err := sa.c.ListRegistrations(query)
l, err := sa.c.ListRegistrations(sa.Context(), query)
if err != nil {
sa.SendInternalServerError(errors.Wrap(err, "scanner API: check existence"))
return false

View File

@ -23,6 +23,7 @@ import (
"github.com/goharbor/harbor/src/lib/q"
"github.com/goharbor/harbor/src/pkg/scan/dao/scanner"
scannertesting "github.com/goharbor/harbor/src/testing/controller/scanner"
"github.com/goharbor/harbor/src/testing/mock"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
@ -107,7 +108,7 @@ func (suite *ScannerAPITestSuite) TestScannerAPIGet() {
Description: "JUST FOR TEST",
URL: "https://a.b.c",
}
suite.mockC.On("GetRegistration", "uuid").Return(res, nil)
suite.mockC.On("GetRegistration", mock.Anything, "uuid").Return(res, nil)
// Get
rr := &scanner.Registration{}
@ -131,7 +132,7 @@ func (suite *ScannerAPITestSuite) TestScannerAPICreate() {
}
suite.mockQuery(r)
suite.mockC.On("CreateRegistration", r).Return("uuid", nil)
suite.mockC.On("CreateRegistration", mock.Anything, r).Return("uuid", nil)
// Create
res := make(map[string]string, 1)
@ -163,7 +164,7 @@ func (suite *ScannerAPITestSuite) TestScannerAPIList() {
Description: "JUST FOR TEST",
URL: "https://a.b.c",
}}
suite.mockC.On("ListRegistrations", query).Return(ll, nil)
suite.mockC.On("ListRegistrations", mock.Anything, query).Return(ll, nil)
// List
l := make([]*scanner.Registration, 0)
@ -198,8 +199,8 @@ func (suite *ScannerAPITestSuite) TestScannerAPIUpdate() {
}
suite.mockQuery(updated)
suite.mockC.On("UpdateRegistration", updated).Return(nil)
suite.mockC.On("GetRegistration", "uuid").Return(before, nil)
suite.mockC.On("UpdateRegistration", mock.Anything, updated).Return(nil)
suite.mockC.On("GetRegistration", mock.Anything, "uuid").Return(before, nil)
rr := &scanner.Registration{}
err := handleAndParse(&testingRequest{
@ -225,8 +226,8 @@ func (suite *ScannerAPITestSuite) TestScannerAPIDelete() {
URL: "https://a.b.c",
}
suite.mockC.On("GetRegistration", "uuid").Return(r, nil)
suite.mockC.On("DeleteRegistration", "uuid").Return(r, nil)
suite.mockC.On("GetRegistration", mock.Anything, "uuid").Return(r, nil)
suite.mockC.On("DeleteRegistration", mock.Anything, "uuid").Return(r, nil)
deleted := &scanner.Registration{}
err := handleAndParse(&testingRequest{
@ -242,7 +243,7 @@ func (suite *ScannerAPITestSuite) TestScannerAPIDelete() {
// TestScannerAPISetDefault tests the set default
func (suite *ScannerAPITestSuite) TestScannerAPISetDefault() {
suite.mockC.On("SetDefaultRegistration", "uuid").Return(nil)
suite.mockC.On("SetDefaultRegistration", mock.Anything, "uuid").Return(nil)
body := make(map[string]interface{}, 1)
body["is_default"] = true
@ -264,12 +265,12 @@ func (suite *ScannerAPITestSuite) mockQuery(r *scanner.Registration) {
Keywords: kw,
}
emptyL := make([]*scanner.Registration, 0)
suite.mockC.On("ListRegistrations", query).Return(emptyL, nil)
suite.mockC.On("ListRegistrations", mock.Anything, query).Return(emptyL, nil)
kw2 := make(map[string]interface{}, 1)
kw2["url"] = r.URL
query2 := &q.Query{
Keywords: kw2,
}
suite.mockC.On("ListRegistrations", query2).Return(emptyL, nil)
suite.mockC.On("ListRegistrations", mock.Anything, query2).Return(emptyL, nil)
}

View File

@ -15,6 +15,7 @@
package main
import (
"context"
"encoding/gob"
"fmt"
"net/url"
@ -47,6 +48,7 @@ import (
_ "github.com/goharbor/harbor/src/lib/cache/redis" // redis cache
"github.com/goharbor/harbor/src/lib/log"
"github.com/goharbor/harbor/src/lib/metric"
"github.com/goharbor/harbor/src/lib/orm"
"github.com/goharbor/harbor/src/migration"
"github.com/goharbor/harbor/src/pkg/notification"
_ "github.com/goharbor/harbor/src/pkg/notifier/topic"
@ -204,7 +206,7 @@ func main() {
log.Fatalf("Failed to initialize API handlers with error: %s", err.Error())
}
registerScanners()
registerScanners(orm.Context())
closing := make(chan struct{})
done := make(chan struct{})
@ -240,7 +242,7 @@ const (
trivyScanner = "Trivy"
)
func registerScanners() {
func registerScanners(ctx context.Context) {
wantedScanners := make([]scanner.Registration, 0)
uninstallScannerNames := make([]string, 0)
@ -258,17 +260,17 @@ func registerScanners() {
uninstallScannerNames = append(uninstallScannerNames, trivyScanner)
}
if err := scan.RemoveImmutableScanners(uninstallScannerNames); err != nil {
if err := scan.RemoveImmutableScanners(ctx, uninstallScannerNames); err != nil {
log.Warningf("failed to remove scanners: %v", err)
}
if err := scan.EnsureScanners(wantedScanners); err != nil {
if err := scan.EnsureScanners(ctx, wantedScanners); err != nil {
log.Fatalf("failed to register scanners: %v", err)
}
if defaultScannerName := getDefaultScannerName(); defaultScannerName != "" {
log.Infof("Setting %s as default scanner", defaultScannerName)
if err := scan.EnsureDefaultScanner(defaultScannerName); err != nil {
if err := scan.EnsureDefaultScanner(ctx, defaultScannerName); err != nil {
log.Fatalf("failed to set default scanner: %v", err)
}
}

View File

@ -20,6 +20,11 @@ import (
"github.com/lib/pq"
)
var (
// ErrNoRows error from the beego orm
ErrNoRows = orm.ErrNoRows
)
// WrapNotFoundError wrap error as NotFoundError when it is orm.ErrNoRows otherwise return err
func WrapNotFoundError(err error, format string, args ...interface{}) error {
if e := AsNotFoundError(err, format, args...); e != nil {

View File

@ -124,10 +124,10 @@ func (suite *VulnerabilityTestSuite) SetupTest() {
// TearDownTest clears enf for test case.
func (suite *VulnerabilityTestSuite) TearDownTest() {
registrations, err := scanner.ListRegistrations(&q.Query{})
registrations, err := scanner.ListRegistrations(suite.Context(), &q.Query{})
require.NoError(suite.T(), err, "Failed to cleanup scanner registrations")
for _, registration := range registrations {
err = scanner.DeleteRegistration(registration.UUID)
err = scanner.DeleteRegistration(suite.Context(), registration.UUID)
require.NoError(suite.T(), err, "Error when cleaning up scanner registrations")
}
reports, err := suite.dao.List(orm.Context(), &q.Query{})
@ -280,7 +280,7 @@ func (suite *VulnerabilityTestSuite) registerScanner(registrationUUID string) {
URL: fmt.Sprintf("https://sample.scanner.com/%s", registrationUUID),
}
_, err := scanner.AddRegistration(r)
_, err := scanner.AddRegistration(suite.Context(), r)
require.NoError(suite.T(), err, "add new registration")
}

View File

@ -15,13 +15,12 @@
package scanner
import (
"context"
"fmt"
"strings"
"github.com/astaxie/beego/orm"
"github.com/goharbor/harbor/src/common/dao"
"github.com/goharbor/harbor/src/lib/errors"
liborm "github.com/goharbor/harbor/src/lib/orm"
"github.com/goharbor/harbor/src/lib/orm"
"github.com/goharbor/harbor/src/lib/q"
)
@ -30,26 +29,32 @@ func init() {
}
// AddRegistration adds a new registration
func AddRegistration(r *Registration) (int64, error) {
o := dao.GetOrmer()
func AddRegistration(ctx context.Context, r *Registration) (int64, error) {
o, err := orm.FromContext(ctx)
if err != nil {
return 0, err
}
id, err := o.Insert(r)
if err != nil {
return 0, liborm.WrapConflictError(err, "registration name or url already exists")
return 0, orm.WrapConflictError(err, "registration name or url already exists")
}
return id, nil
}
// GetRegistration gets the specified registration
func GetRegistration(UUID string) (*Registration, error) {
e := &Registration{}
func GetRegistration(ctx context.Context, UUID string) (*Registration, error) {
o, err := orm.FromContext(ctx)
if err != nil {
return nil, err
}
o := dao.GetOrmer()
e := &Registration{}
qs := o.QueryTable(new(Registration))
if err := qs.Filter("uuid", UUID).One(e); err != nil {
if err == orm.ErrNoRows {
if errors.Is(err, orm.ErrNoRows) {
// Not existing case
return nil, nil
}
@ -60,8 +65,12 @@ func GetRegistration(UUID string) (*Registration, error) {
}
// UpdateRegistration update the specified registration
func UpdateRegistration(r *Registration, cols ...string) error {
o := dao.GetOrmer()
func UpdateRegistration(ctx context.Context, r *Registration, cols ...string) error {
o, err := orm.FromContext(ctx)
if err != nil {
return err
}
count, err := o.Update(r, cols...)
if err != nil {
return err
@ -75,8 +84,12 @@ func UpdateRegistration(r *Registration, cols ...string) error {
}
// DeleteRegistration deletes the registration with the specified UUID
func DeleteRegistration(UUID string) error {
o := dao.GetOrmer()
func DeleteRegistration(ctx context.Context, UUID string) error {
o, err := orm.FromContext(ctx)
if err != nil {
return err
}
qt := o.QueryTable(new(Registration))
// delete with query way
@ -94,8 +107,12 @@ func DeleteRegistration(UUID string) error {
}
// ListRegistrations lists all the existing registrations
func ListRegistrations(query *q.Query) ([]*Registration, error) {
o := dao.GetOrmer()
func ListRegistrations(ctx context.Context, query *q.Query) ([]*Registration, error) {
o, err := orm.FromContext(ctx)
if err != nil {
return nil, err
}
qt := o.QueryTable(new(Registration))
if query != nil {
@ -107,7 +124,7 @@ func ListRegistrations(query *q.Query) ([]*Registration, error) {
continue
}
if s, ok := v.(string); ok {
v = liborm.Escape(s)
v = orm.Escape(s)
}
qt = qt.Filter(fmt.Sprintf("%s__icontains", k), v)
@ -123,58 +140,58 @@ func ListRegistrations(query *q.Query) ([]*Registration, error) {
qt = qt.OrderBy("-is_default", "-create_time")
l := make([]*Registration, 0)
_, err := qt.All(&l)
_, err = qt.All(&l)
return l, err
}
// SetDefaultRegistration sets the specified registration as default one
func SetDefaultRegistration(UUID string) error {
o := orm.NewOrm()
err := o.Begin()
if err != nil {
return err
}
func SetDefaultRegistration(ctx context.Context, UUID string) error {
f := func(ctx context.Context) error {
o, err := orm.FromContext(ctx)
if err != nil {
return err
}
var count int64
qt := o.QueryTable(new(Registration))
count, err = qt.Filter("uuid", UUID).
Filter("disabled", false).
Update(orm.Params{
"is_default": true,
})
if err == nil && count == 0 {
err = errors.Errorf("set default for %s failed", UUID)
}
var count int64
qt := o.QueryTable(new(Registration))
count, err = qt.Filter("uuid", UUID).
Filter("disabled", false).
Update(orm.Params{
"is_default": true,
})
if err != nil {
return err
}
if count == 0 {
return errors.Errorf("set default for %s failed", UUID)
}
if err == nil {
qt2 := o.QueryTable(new(Registration))
_, err = qt2.Exclude("uuid__exact", UUID).
Filter("is_default", true).
Update(orm.Params{
"is_default": false,
})
return err
}
if err != nil {
if e := o.Rollback(); e != nil {
err = errors.Wrap(e, err.Error())
}
} else {
err = o.Commit()
}
return err
return orm.WithTransaction(f)(ctx)
}
// GetDefaultRegistration gets the default registration
func GetDefaultRegistration() (*Registration, error) {
o := dao.GetOrmer()
func GetDefaultRegistration(ctx context.Context) (*Registration, error) {
o, err := orm.FromContext(ctx)
if err != nil {
return nil, err
}
qt := o.QueryTable(new(Registration))
e := &Registration{}
if err := qt.Filter("is_default", true).One(e); err != nil {
if err == orm.ErrNoRows {
if errors.Is(err, orm.ErrNoRows) {
return nil, nil
}

View File

@ -17,8 +17,8 @@ package scanner
import (
"testing"
"github.com/goharbor/harbor/src/common/dao"
"github.com/goharbor/harbor/src/lib/q"
htesting "github.com/goharbor/harbor/src/testing"
"github.com/google/uuid"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@ -27,7 +27,7 @@ import (
// RegistrationDAOTestSuite is test suite of testing registration DAO
type RegistrationDAOTestSuite struct {
suite.Suite
htesting.Suite
registrationID string
}
@ -39,7 +39,7 @@ func TestRegistrationDAO(t *testing.T) {
// SetupSuite prepare testing env for the suite
func (suite *RegistrationDAOTestSuite) SetupSuite() {
dao.PrepareTestForPostgresSQL()
suite.Suite.SetupSuite()
}
// SetupTest prepare stuff for test cases
@ -52,34 +52,34 @@ func (suite *RegistrationDAOTestSuite) SetupTest() {
URL: "https://sample.scanner.com",
}
_, err := AddRegistration(r)
_, err := AddRegistration(suite.Context(), r)
require.NoError(suite.T(), err, "add new registration")
}
// TearDownTest clears all the stuff of test cases
func (suite *RegistrationDAOTestSuite) TearDownTest() {
err := DeleteRegistration(suite.registrationID)
err := DeleteRegistration(suite.Context(), suite.registrationID)
require.NoError(suite.T(), err, "clear registration")
}
// TestGet tests get registration
func (suite *RegistrationDAOTestSuite) TestGet() {
// Found
r, err := GetRegistration(suite.registrationID)
r, err := GetRegistration(suite.Context(), suite.registrationID)
require.NoError(suite.T(), err)
require.NotNil(suite.T(), r)
assert.Equal(suite.T(), r.Name, "forUT")
// Not found
re, err := GetRegistration("not_found")
re, err := GetRegistration(suite.Context(), "not_found")
require.NoError(suite.T(), err)
require.Nil(suite.T(), re)
}
// TestUpdate tests update registration
func (suite *RegistrationDAOTestSuite) TestUpdate() {
r, err := GetRegistration(suite.registrationID)
r, err := GetRegistration(suite.Context(), suite.registrationID)
require.NoError(suite.T(), err)
require.NotNil(suite.T(), r)
@ -87,10 +87,10 @@ func (suite *RegistrationDAOTestSuite) TestUpdate() {
r.IsDefault = true
r.URL = "http://updated.registration.com"
err = UpdateRegistration(r)
err = UpdateRegistration(suite.Context(), r)
require.NoError(suite.T(), err, "update registration")
r, err = GetRegistration(suite.registrationID)
r, err = GetRegistration(suite.Context(), suite.registrationID)
require.NoError(suite.T(), err)
require.NotNil(suite.T(), r)
@ -102,14 +102,14 @@ func (suite *RegistrationDAOTestSuite) TestUpdate() {
// TestList tests list registrations
func (suite *RegistrationDAOTestSuite) TestList() {
// no query
l, err := ListRegistrations(nil)
l, err := ListRegistrations(suite.Context(), nil)
require.NoError(suite.T(), err)
require.Equal(suite.T(), 1, len(l))
// with query and found items
keywords := make(map[string]interface{})
keywords["description"] = "sample"
l, err = ListRegistrations(&q.Query{
l, err = ListRegistrations(suite.Context(), &q.Query{
PageSize: 5,
PageNumber: 1,
Keywords: keywords,
@ -119,7 +119,7 @@ func (suite *RegistrationDAOTestSuite) TestList() {
// With query and not found items
keywords["description"] = "not_exist"
l, err = ListRegistrations(&q.Query{
l, err = ListRegistrations(suite.Context(), &q.Query{
Keywords: keywords,
})
require.NoError(suite.T(), err)
@ -128,14 +128,14 @@ func (suite *RegistrationDAOTestSuite) TestList() {
// Exact match
exactKeywords := make(map[string]interface{})
exactKeywords["ex_name"] = "forUT"
l, err = ListRegistrations(&q.Query{
l, err = ListRegistrations(suite.Context(), &q.Query{
Keywords: exactKeywords,
})
require.NoError(suite.T(), err)
require.Equal(suite.T(), 1, len(l))
exactKeywords["ex_name"] = "forU"
l, err = ListRegistrations(&q.Query{
l, err = ListRegistrations(suite.Context(), &q.Query{
Keywords: exactKeywords,
})
require.NoError(suite.T(), err)
@ -144,21 +144,21 @@ func (suite *RegistrationDAOTestSuite) TestList() {
// TestDefault tests set/get default
func (suite *RegistrationDAOTestSuite) TestDefault() {
dr, err := GetDefaultRegistration()
dr, err := GetDefaultRegistration(suite.Context())
require.NoError(suite.T(), err, "not found")
require.Nil(suite.T(), dr)
err = SetDefaultRegistration(suite.registrationID)
err = SetDefaultRegistration(suite.Context(), suite.registrationID)
require.NoError(suite.T(), err)
dr, err = GetDefaultRegistration()
dr, err = GetDefaultRegistration(suite.Context())
require.NoError(suite.T(), err)
require.NotNil(suite.T(), dr)
dr.Disabled = true
err = UpdateRegistration(dr, "disabled")
err = UpdateRegistration(suite.Context(), dr, "disabled")
require.NoError(suite.T(), err)
err = SetDefaultRegistration(suite.registrationID)
err = SetDefaultRegistration(suite.Context(), suite.registrationID)
require.Error(suite.T(), err)
}

View File

@ -15,6 +15,8 @@
package scan
import (
"context"
"github.com/goharbor/harbor/src/lib/errors"
"github.com/goharbor/harbor/src/lib/log"
"github.com/goharbor/harbor/src/lib/q"
@ -27,7 +29,7 @@ var (
)
// EnsureScanners ensures that the scanners with the specified endpoints URLs exist in the system.
func EnsureScanners(wantedScanners []scanner.Registration) (err error) {
func EnsureScanners(ctx context.Context, wantedScanners []scanner.Registration) (err error) {
if len(wantedScanners) == 0 {
return
}
@ -36,7 +38,7 @@ func EnsureScanners(wantedScanners []scanner.Registration) (err error) {
names[i] = ws.Name
}
list, err := scannerManager.List(q.New(q.KeyWords{"ex_name__in": names}))
list, err := scannerManager.List(ctx, q.New(q.KeyWords{"ex_name__in": names}))
if err != nil {
return errors.Errorf("listing scanners: %v", err)
}
@ -48,13 +50,13 @@ func EnsureScanners(wantedScanners []scanner.Registration) (err error) {
for _, ws := range wantedScanners {
scanner, exists := existingScanners[ws.Name]
if !exists {
if _, err := scannerManager.Create(&ws); err != nil {
if _, err := scannerManager.Create(ctx, &ws); err != nil {
return errors.Errorf("creating registration %s at %s failed: %v", ws.Name, ws.URL, err)
}
log.Infof("Successfully registered %s scanner at %s", ws.Name, ws.URL)
} else if scanner.URL != ws.URL {
scanner.URL = ws.URL
if err := scannerManager.Update(scanner); err != nil {
if err := scannerManager.Update(ctx, scanner); err != nil {
return errors.Errorf("updating registration %s to %s failed: %v", ws.Name, ws.URL, err)
}
log.Infof("Successfully updated %s scanner to %s", ws.Name, ws.URL)
@ -67,8 +69,8 @@ func EnsureScanners(wantedScanners []scanner.Registration) (err error) {
}
// EnsureDefaultScanner ensures that the scanner with the specified URL is set as default in the system.
func EnsureDefaultScanner(scannerName string) (err error) {
defaultScanner, err := scannerManager.GetDefault()
func EnsureDefaultScanner(ctx context.Context, scannerName string) (err error) {
defaultScanner, err := scannerManager.GetDefault(ctx)
if err != nil {
err = errors.Errorf("getting default scanner: %v", err)
return
@ -77,7 +79,7 @@ func EnsureDefaultScanner(scannerName string) (err error) {
log.Infof("Skipped setting %s as the default scanner. The default scanner is already set to %s", scannerName, defaultScanner.URL)
return
}
scanners, err := scannerManager.List(q.New(q.KeyWords{"ex_name": scannerName}))
scanners, err := scannerManager.List(ctx, q.New(q.KeyWords{"ex_name": scannerName}))
if err != nil {
err = errors.Errorf("listing scanners: %v", err)
return
@ -85,7 +87,7 @@ func EnsureDefaultScanner(scannerName string) (err error) {
if len(scanners) != 1 {
return errors.Errorf("expected only one scanner with name %v but got %d", scannerName, len(scanners))
}
err = scannerManager.SetAsDefault(scanners[0].UUID)
err = scannerManager.SetAsDefault(ctx, scanners[0].UUID)
if err != nil {
err = errors.Errorf("setting %s as default scanner: %v", scannerName, err)
}
@ -93,20 +95,20 @@ func EnsureDefaultScanner(scannerName string) (err error) {
}
// RemoveImmutableScanners removes immutable scanner Registrations with the specified endpoint URLs.
func RemoveImmutableScanners(names []string) error {
func RemoveImmutableScanners(ctx context.Context, names []string) error {
if len(names) == 0 {
return nil
}
query := q.New(q.KeyWords{"ex_immutable": true, "ex_name__in": names})
// TODO Instead of executing 1 to N SQL queries we might want to delete multiple rows with scannerManager.DeleteByImmutableAndURLIn(true, []string{})
registrations, err := scannerManager.List(query)
registrations, err := scannerManager.List(ctx, query)
if err != nil {
return errors.Errorf("listing scanners: %v", err)
}
for _, reg := range registrations {
if err := scannerManager.Delete(reg.UUID); err != nil {
if err := scannerManager.Delete(ctx, reg.UUID); err != nil {
return errors.Errorf("deleting scanner: %s: %v", reg.UUID, err)
}
}

View File

@ -15,19 +15,21 @@
package scan
import (
"context"
"testing"
"github.com/goharbor/harbor/src/lib/errors"
"github.com/goharbor/harbor/src/lib/q"
"github.com/goharbor/harbor/src/pkg/scan/dao/scanner"
"github.com/goharbor/harbor/src/pkg/scan/scanner/mocks"
"github.com/goharbor/harbor/src/testing/mock"
mocks "github.com/goharbor/harbor/src/testing/pkg/scan/scanner"
"github.com/stretchr/testify/assert"
)
func TestEnsureScanners(t *testing.T) {
t.Run("Should do nothing when list of wanted scanners is empty", func(t *testing.T) {
err := EnsureScanners([]scanner.Registration{})
err := EnsureScanners(context.TODO(), []scanner.Registration{})
assert.NoError(t, err)
})
@ -35,13 +37,13 @@ func TestEnsureScanners(t *testing.T) {
mgr := &mocks.Manager{}
scannerManager = mgr
mgr.On("List", &q.Query{
mgr.On("List", mock.Anything, &q.Query{
Keywords: map[string]interface{}{
"ex_name__in": []string{"scanner"},
},
}).Return(nil, errors.New("DB error"))
err := EnsureScanners([]scanner.Registration{
err := EnsureScanners(context.TODO(), []scanner.Registration{
{Name: "scanner", URL: "http://scanner:8080"},
})
@ -53,19 +55,19 @@ func TestEnsureScanners(t *testing.T) {
mgr := &mocks.Manager{}
scannerManager = mgr
mgr.On("List", &q.Query{
mgr.On("List", mock.Anything, &q.Query{
Keywords: map[string]interface{}{
"ex_name__in": []string{
"trivy",
},
},
}).Return([]*scanner.Registration{}, nil)
mgr.On("Create", &scanner.Registration{
mgr.On("Create", mock.Anything, &scanner.Registration{
Name: "trivy",
URL: "http://trivy:8080",
}).Return("uuid-trivy", nil)
err := EnsureScanners([]scanner.Registration{
err := EnsureScanners(context.TODO(), []scanner.Registration{
{Name: "trivy", URL: "http://trivy:8080"},
})
@ -77,7 +79,7 @@ func TestEnsureScanners(t *testing.T) {
mgr := &mocks.Manager{}
scannerManager = mgr
mgr.On("List", &q.Query{
mgr.On("List", mock.Anything, &q.Query{
Keywords: map[string]interface{}{
"ex_name__in": []string{
"trivy",
@ -86,12 +88,12 @@ func TestEnsureScanners(t *testing.T) {
}).Return([]*scanner.Registration{
{Name: "trivy", URL: "http://trivy:8080"},
}, nil)
mgr.On("Update", &scanner.Registration{
mgr.On("Update", mock.Anything, &scanner.Registration{
Name: "trivy",
URL: "http://trivy:8443",
}).Return(nil)
err := EnsureScanners([]scanner.Registration{
err := EnsureScanners(context.TODO(), []scanner.Registration{
{Name: "trivy", URL: "http://trivy:8443"},
})
@ -107,9 +109,9 @@ func TestEnsureDefaultScanner(t *testing.T) {
mgr := &mocks.Manager{}
scannerManager = mgr
mgr.On("GetDefault").Return(nil, errors.New("DB error"))
mgr.On("GetDefault", mock.Anything).Return(nil, errors.New("DB error"))
err := EnsureDefaultScanner("trivy")
err := EnsureDefaultScanner(context.TODO(), "trivy")
assert.EqualError(t, err, "getting default scanner: DB error")
mgr.AssertExpectations(t)
})
@ -118,11 +120,11 @@ func TestEnsureDefaultScanner(t *testing.T) {
mgr := &mocks.Manager{}
scannerManager = mgr
mgr.On("GetDefault").Return(&scanner.Registration{
mgr.On("GetDefault", mock.Anything).Return(&scanner.Registration{
Name: "trivy",
}, nil)
err := EnsureDefaultScanner("trivy")
err := EnsureDefaultScanner(context.TODO(), "trivy")
assert.NoError(t, err)
mgr.AssertExpectations(t)
})
@ -131,12 +133,12 @@ func TestEnsureDefaultScanner(t *testing.T) {
mgr := &mocks.Manager{}
scannerManager = mgr
mgr.On("GetDefault").Return(nil, nil)
mgr.On("List", &q.Query{
mgr.On("GetDefault", mock.Anything).Return(nil, nil)
mgr.On("List", mock.Anything, &q.Query{
Keywords: map[string]interface{}{"ex_name": "trivy"},
}).Return(nil, errors.New("DB error"))
err := EnsureDefaultScanner("trivy")
err := EnsureDefaultScanner(context.TODO(), "trivy")
assert.EqualError(t, err, "listing scanners: DB error")
mgr.AssertExpectations(t)
})
@ -145,15 +147,15 @@ func TestEnsureDefaultScanner(t *testing.T) {
mgr := &mocks.Manager{}
scannerManager = mgr
mgr.On("GetDefault").Return(nil, nil)
mgr.On("List", &q.Query{
mgr.On("GetDefault", mock.Anything).Return(nil, nil)
mgr.On("List", mock.Anything, &q.Query{
Keywords: map[string]interface{}{"ex_name": "trivy"},
}).Return([]*scanner.Registration{
{Name: "trivy"},
{Name: "trivy"},
}, nil)
err := EnsureDefaultScanner("trivy")
err := EnsureDefaultScanner(context.TODO(), "trivy")
assert.EqualError(t, err, "expected only one scanner with name trivy but got 2")
mgr.AssertExpectations(t)
})
@ -162,8 +164,8 @@ func TestEnsureDefaultScanner(t *testing.T) {
mgr := &mocks.Manager{}
scannerManager = mgr
mgr.On("GetDefault").Return(nil, nil)
mgr.On("List", &q.Query{
mgr.On("GetDefault", mock.Anything).Return(nil, nil)
mgr.On("List", mock.Anything, &q.Query{
Keywords: map[string]interface{}{"ex_name": "trivy"},
}).Return([]*scanner.Registration{
{
@ -172,9 +174,9 @@ func TestEnsureDefaultScanner(t *testing.T) {
URL: "http://trivy:8080",
},
}, nil)
mgr.On("SetAsDefault", "trivy-uuid").Return(nil)
mgr.On("SetAsDefault", mock.Anything, "trivy-uuid").Return(nil)
err := EnsureDefaultScanner("trivy")
err := EnsureDefaultScanner(context.TODO(), "trivy")
assert.NoError(t, err)
mgr.AssertExpectations(t)
})
@ -183,8 +185,8 @@ func TestEnsureDefaultScanner(t *testing.T) {
mgr := &mocks.Manager{}
scannerManager = mgr
mgr.On("GetDefault").Return(nil, nil)
mgr.On("List", &q.Query{
mgr.On("GetDefault", mock.Anything).Return(nil, nil)
mgr.On("List", mock.Anything, &q.Query{
Keywords: map[string]interface{}{"ex_name": "trivy"},
}).Return([]*scanner.Registration{
{
@ -193,9 +195,9 @@ func TestEnsureDefaultScanner(t *testing.T) {
URL: "http://trivy:8080",
},
}, nil)
mgr.On("SetAsDefault", "trivy-uuid").Return(errors.New("DB error"))
mgr.On("SetAsDefault", mock.Anything, "trivy-uuid").Return(errors.New("DB error"))
err := EnsureDefaultScanner("trivy")
err := EnsureDefaultScanner(context.TODO(), "trivy")
assert.EqualError(t, err, "setting trivy as default scanner: DB error")
mgr.AssertExpectations(t)
})
@ -208,7 +210,7 @@ func TestRemoveImmutableScanners(t *testing.T) {
mgr := &mocks.Manager{}
scannerManager = mgr
err := RemoveImmutableScanners([]string{})
err := RemoveImmutableScanners(context.TODO(), []string{})
assert.NoError(t, err)
mgr.AssertExpectations(t)
})
@ -217,14 +219,14 @@ func TestRemoveImmutableScanners(t *testing.T) {
mgr := &mocks.Manager{}
scannerManager = mgr
mgr.On("List", &q.Query{
mgr.On("List", mock.Anything, &q.Query{
Keywords: map[string]interface{}{
"ex_immutable": true,
"ex_name__in": []string{"scanner"},
},
}).Return(nil, errors.New("DB error"))
err := RemoveImmutableScanners([]string{"scanner"})
err := RemoveImmutableScanners(context.TODO(), []string{"scanner"})
assert.EqualError(t, err, "listing scanners: DB error")
mgr.AssertExpectations(t)
})
@ -245,7 +247,7 @@ func TestRemoveImmutableScanners(t *testing.T) {
URL: "http://scanner-2",
}}
mgr.On("List", &q.Query{
mgr.On("List", mock.Anything, &q.Query{
Keywords: map[string]interface{}{
"ex_immutable": true,
"ex_name__in": []string{
@ -254,10 +256,10 @@ func TestRemoveImmutableScanners(t *testing.T) {
},
},
}).Return(registrations, nil)
mgr.On("Delete", "uuid-1").Return(nil)
mgr.On("Delete", "uuid-2").Return(nil)
mgr.On("Delete", mock.Anything, "uuid-1").Return(nil)
mgr.On("Delete", mock.Anything, "uuid-2").Return(nil)
err := RemoveImmutableScanners([]string{
err := RemoveImmutableScanners(context.TODO(), []string{
"scanner-1",
"scanner-2",
})
@ -281,7 +283,7 @@ func TestRemoveImmutableScanners(t *testing.T) {
URL: "http://scanner-2",
}}
mgr.On("List", &q.Query{
mgr.On("List", mock.Anything, &q.Query{
Keywords: map[string]interface{}{
"ex_immutable": true,
"ex_name__in": []string{
@ -290,10 +292,10 @@ func TestRemoveImmutableScanners(t *testing.T) {
},
},
}).Return(registrations, nil)
mgr.On("Delete", "uuid-1").Return(nil)
mgr.On("Delete", "uuid-2").Return(errors.New("DB error"))
mgr.On("Delete", mock.Anything, "uuid-1").Return(nil)
mgr.On("Delete", mock.Anything, "uuid-2").Return(errors.New("DB error"))
err := RemoveImmutableScanners([]string{
err := RemoveImmutableScanners(context.TODO(), []string{
"scanner-1",
"scanner-2",
})

View File

@ -1,7 +1,24 @@
// Copyright Project Harbor Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package postprocessors
import (
"encoding/json"
"testing"
"time"
"github.com/goharbor/harbor/src/jobservice/job"
"github.com/goharbor/harbor/src/lib/orm"
"github.com/goharbor/harbor/src/lib/q"
@ -14,8 +31,6 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
"testing"
"time"
)
const sampleReport = `{
@ -133,7 +148,7 @@ const sampleReportWithCompleteVulnData = `{
"vector_v3": "CVSS:3.0/AV:L/AC:L/PR:L/UI:N/S:U/C:H/I:N/A:N",
"vector_v2": "AV:L/AC:M/Au:N/C:P/I:N/A:N"
},
"vendor_attributes":{
"vendor_attributes":{
"CVSS":{
"nvd" : {
"V2Score": 7.1,
@ -185,7 +200,7 @@ const sampleReportWithMixedSeverity = `{
"vector_v3": "CVSS:3.0/AV:L/AC:L/PR:L/UI:N/S:U/C:H/I:N/A:N",
"vector_v2": "AV:L/AC:M/Au:N/C:P/I:N/A:N"
},
"vendor_attributes":{
"vendor_attributes":{
"CVSS":{
"nvd" : {
"V2Score": 7.1,
@ -220,7 +235,7 @@ const sampleReportWithMixedSeverity = `{
"vector_v3": "CVSS:3.0/AV:L/AC:L/PR:L/UI:N/S:U/C:H/I:N/A:N",
"vector_v2": "AV:L/AC:M/Au:N/C:P/I:N/A:N"
},
"vendor_attributes":{
"vendor_attributes":{
"CVSS":{
"nvd" : {
"V2Score": 7.1,
@ -255,7 +270,7 @@ const sampleReportWithMixedSeverity = `{
"vector_v3": "CVSS:3.0/AV:L/AC:L/PR:L/UI:N/S:U/C:H/I:N/A:N",
"vector_v2": "AV:L/AC:M/Au:N/C:P/I:N/A:N"
},
"vendor_attributes":{
"vendor_attributes":{
"CVSS":{
"nvd" : {
"V2Score": 7.1,
@ -291,7 +306,7 @@ func (suite *TestReportConverterSuite) SetupTest() {
URL: "https://sample.scanner.com",
}
_, err := scanner.AddRegistration(r)
_, err := scanner.AddRegistration(suite.Context(), r)
require.NoError(suite.T(), err, "add new registration")
}
@ -312,7 +327,7 @@ func (suite *TestReportConverterSuite) SetupSuite() {
func (suite *TestReportConverterSuite) TearDownTest() {
// No delete method defined in manager as no requirement,
// so, to clear env, call dao method here
scanner.DeleteRegistration(suite.registrationID)
scanner.DeleteRegistration(suite.Context(), suite.registrationID)
reports, err := suite.reportDao.List(orm.Context(), &q.Query{})
require.True(suite.T(), err == nil, "Failed to delete vulnerability records")
for _, report := range reports {

View File

@ -15,6 +15,8 @@
package scanner
import (
"context"
"github.com/goharbor/harbor/src/lib/errors"
"github.com/goharbor/harbor/src/lib/q"
"github.com/goharbor/harbor/src/pkg/scan/dao/scanner"
@ -25,27 +27,27 @@ import (
type Manager interface {
// List returns a list of currently configured scanner registrations.
// Query parameters are optional
List(query *q.Query) ([]*scanner.Registration, error)
List(ctx context.Context, query *q.Query) ([]*scanner.Registration, error)
// Create creates a new scanner registration with the given data.
// Returns the scanner registration identifier.
Create(registration *scanner.Registration) (string, error)
Create(ctx context.Context, registration *scanner.Registration) (string, error)
// Get returns the details of the specified scanner registration.
Get(registrationUUID string) (*scanner.Registration, error)
Get(ctx context.Context, registrationUUID string) (*scanner.Registration, error)
// Update updates the specified scanner registration.
Update(registration *scanner.Registration) error
Update(ctx context.Context, registration *scanner.Registration) error
// Delete deletes the specified scanner registration.
Delete(registrationUUID string) error
Delete(ctx context.Context, registrationUUID string) error
// SetAsDefault marks the specified scanner registration as default.
// The implementation is supposed to unset any registration previously set as default.
SetAsDefault(registrationUUID string) error
SetAsDefault(ctx context.Context, registrationUUID string) error
// GetDefault returns the default scanner registration or `nil` if there are no registrations configured.
GetDefault() (*scanner.Registration, error)
GetDefault(ctx context.Context) (*scanner.Registration, error)
}
// basicManager is the default implementation of Manager
@ -57,7 +59,7 @@ func New() Manager {
}
// Create ...
func (bm *basicManager) Create(registration *scanner.Registration) (string, error) {
func (bm *basicManager) Create(ctx context.Context, registration *scanner.Registration) (string, error) {
if registration == nil {
return "", errors.New("nil registration to create")
}
@ -73,7 +75,7 @@ func (bm *basicManager) Create(registration *scanner.Registration) (string, erro
return "", errors.Wrap(err, "create registration")
}
if _, err := scanner.AddRegistration(registration); err != nil {
if _, err := scanner.AddRegistration(ctx, registration); err != nil {
return "", errors.Wrap(err, "dao: create registration")
}
@ -81,16 +83,16 @@ func (bm *basicManager) Create(registration *scanner.Registration) (string, erro
}
// Get ...
func (bm *basicManager) Get(registrationUUID string) (*scanner.Registration, error) {
func (bm *basicManager) Get(ctx context.Context, registrationUUID string) (*scanner.Registration, error) {
if len(registrationUUID) == 0 {
return nil, errors.New("empty uuid of registration")
}
return scanner.GetRegistration(registrationUUID)
return scanner.GetRegistration(ctx, registrationUUID)
}
// Update ...
func (bm *basicManager) Update(registration *scanner.Registration) error {
func (bm *basicManager) Update(ctx context.Context, registration *scanner.Registration) error {
if registration == nil {
return errors.New("nil registration to update")
}
@ -99,33 +101,33 @@ func (bm *basicManager) Update(registration *scanner.Registration) error {
return errors.Wrap(err, "update registration")
}
return scanner.UpdateRegistration(registration)
return scanner.UpdateRegistration(ctx, registration)
}
// Delete ...
func (bm *basicManager) Delete(registrationUUID string) error {
func (bm *basicManager) Delete(ctx context.Context, registrationUUID string) error {
if len(registrationUUID) == 0 {
return errors.New("empty UUID to delete")
}
return scanner.DeleteRegistration(registrationUUID)
return scanner.DeleteRegistration(ctx, registrationUUID)
}
// List ...
func (bm *basicManager) List(query *q.Query) ([]*scanner.Registration, error) {
return scanner.ListRegistrations(query)
func (bm *basicManager) List(ctx context.Context, query *q.Query) ([]*scanner.Registration, error) {
return scanner.ListRegistrations(ctx, query)
}
// SetAsDefault ...
func (bm *basicManager) SetAsDefault(registrationUUID string) error {
func (bm *basicManager) SetAsDefault(ctx context.Context, registrationUUID string) error {
if len(registrationUUID) == 0 {
return errors.New("empty UUID to set default")
}
return scanner.SetDefaultRegistration(registrationUUID)
return scanner.SetDefaultRegistration(ctx, registrationUUID)
}
// GetDefault ...
func (bm *basicManager) GetDefault() (*scanner.Registration, error) {
return scanner.GetDefaultRegistration()
func (bm *basicManager) GetDefault(ctx context.Context) (*scanner.Registration, error) {
return scanner.GetDefaultRegistration(ctx)
}

View File

@ -17,9 +17,9 @@ package scanner
import (
"testing"
"github.com/goharbor/harbor/src/common/dao"
"github.com/goharbor/harbor/src/lib/q"
"github.com/goharbor/harbor/src/pkg/scan/dao/scanner"
htesting "github.com/goharbor/harbor/src/testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
@ -27,7 +27,7 @@ import (
// BasicManagerTestSuite tests the basic manager
type BasicManagerTestSuite struct {
suite.Suite
htesting.Suite
mgr Manager
sampleUUID string
@ -40,7 +40,7 @@ func TestBasicManager(t *testing.T) {
// SetupSuite prepares env for test suite
func (suite *BasicManagerTestSuite) SetupSuite() {
dao.PrepareTestForPostgresSQL()
suite.Suite.SetupSuite()
suite.mgr = New()
@ -50,14 +50,14 @@ func (suite *BasicManagerTestSuite) SetupSuite() {
URL: "https://sample.scanner.com",
}
uid, err := suite.mgr.Create(r)
uid, err := suite.mgr.Create(suite.Context(), r)
require.NoError(suite.T(), err)
suite.sampleUUID = uid
}
// TearDownSuite clears env for test suite
func (suite *BasicManagerTestSuite) TearDownSuite() {
err := suite.mgr.Delete(suite.sampleUUID)
err := suite.mgr.Delete(suite.Context(), suite.sampleUUID)
require.NoError(suite.T(), err, "delete registration")
}
@ -66,7 +66,7 @@ func (suite *BasicManagerTestSuite) TestList() {
m := make(map[string]interface{}, 1)
m["name"] = "forUT"
l, err := suite.mgr.List(&q.Query{
l, err := suite.mgr.List(suite.Context(), &q.Query{
PageNumber: 1,
PageSize: 10,
Keywords: m,
@ -78,7 +78,7 @@ func (suite *BasicManagerTestSuite) TestList() {
// TestGet tests get registration
func (suite *BasicManagerTestSuite) TestGet() {
r, err := suite.mgr.Get(suite.sampleUUID)
r, err := suite.mgr.Get(suite.Context(), suite.sampleUUID)
require.NoError(suite.T(), err)
require.NotNil(suite.T(), r)
assert.Equal(suite.T(), "forUT", r.Name)
@ -86,15 +86,15 @@ func (suite *BasicManagerTestSuite) TestGet() {
// TestUpdate tests update registration
func (suite *BasicManagerTestSuite) TestUpdate() {
r, err := suite.mgr.Get(suite.sampleUUID)
r, err := suite.mgr.Get(suite.Context(), suite.sampleUUID)
require.NoError(suite.T(), err)
require.NotNil(suite.T(), r)
r.URL = "https://updated.com"
err = suite.mgr.Update(r)
err = suite.mgr.Update(suite.Context(), r)
require.NoError(suite.T(), err)
r, err = suite.mgr.Get(suite.sampleUUID)
r, err = suite.mgr.Get(suite.Context(), suite.sampleUUID)
require.NoError(suite.T(), err)
require.NotNil(suite.T(), r)
assert.Equal(suite.T(), "https://updated.com", r.URL)
@ -102,10 +102,10 @@ func (suite *BasicManagerTestSuite) TestUpdate() {
// TestDefault tests get/set default registration
func (suite *BasicManagerTestSuite) TestDefault() {
err := suite.mgr.SetAsDefault(suite.sampleUUID)
err := suite.mgr.SetAsDefault(suite.Context(), suite.sampleUUID)
require.NoError(suite.T(), err)
dr, err := suite.mgr.GetDefault()
dr, err := suite.mgr.GetDefault(suite.Context())
require.NoError(suite.T(), err)
require.NotNil(suite.T(), dr)
assert.Equal(suite.T(), true, dr.IsDefault)

View File

@ -1,147 +0,0 @@
// Code generated by mockery v1.0.0. DO NOT EDIT.
package mocks
import (
q "github.com/goharbor/harbor/src/lib/q"
mock "github.com/stretchr/testify/mock"
scanner "github.com/goharbor/harbor/src/pkg/scan/dao/scanner"
)
// Manager is an autogenerated mock type for the Manager type
type Manager struct {
mock.Mock
}
// Create provides a mock function with given fields: registration
func (_m *Manager) Create(registration *scanner.Registration) (string, error) {
ret := _m.Called(registration)
var r0 string
if rf, ok := ret.Get(0).(func(*scanner.Registration) string); ok {
r0 = rf(registration)
} else {
r0 = ret.Get(0).(string)
}
var r1 error
if rf, ok := ret.Get(1).(func(*scanner.Registration) error); ok {
r1 = rf(registration)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// Delete provides a mock function with given fields: registrationUUID
func (_m *Manager) Delete(registrationUUID string) error {
ret := _m.Called(registrationUUID)
var r0 error
if rf, ok := ret.Get(0).(func(string) error); ok {
r0 = rf(registrationUUID)
} else {
r0 = ret.Error(0)
}
return r0
}
// Get provides a mock function with given fields: registrationUUID
func (_m *Manager) Get(registrationUUID string) (*scanner.Registration, error) {
ret := _m.Called(registrationUUID)
var r0 *scanner.Registration
if rf, ok := ret.Get(0).(func(string) *scanner.Registration); ok {
r0 = rf(registrationUUID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*scanner.Registration)
}
}
var r1 error
if rf, ok := ret.Get(1).(func(string) error); ok {
r1 = rf(registrationUUID)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// GetDefault provides a mock function with given fields:
func (_m *Manager) GetDefault() (*scanner.Registration, error) {
ret := _m.Called()
var r0 *scanner.Registration
if rf, ok := ret.Get(0).(func() *scanner.Registration); ok {
r0 = rf()
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*scanner.Registration)
}
}
var r1 error
if rf, ok := ret.Get(1).(func() error); ok {
r1 = rf()
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// List provides a mock function with given fields: query
func (_m *Manager) List(query *q.Query) ([]*scanner.Registration, error) {
ret := _m.Called(query)
var r0 []*scanner.Registration
if rf, ok := ret.Get(0).(func(*q.Query) []*scanner.Registration); ok {
r0 = rf(query)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*scanner.Registration)
}
}
var r1 error
if rf, ok := ret.Get(1).(func(*q.Query) error); ok {
r1 = rf(query)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// SetAsDefault provides a mock function with given fields: registrationUUID
func (_m *Manager) SetAsDefault(registrationUUID string) error {
ret := _m.Called(registrationUUID)
var r0 error
if rf, ok := ret.Get(0).(func(string) error); ok {
r0 = rf(registrationUUID)
} else {
r0 = ret.Error(0)
}
return r0
}
// Update provides a mock function with given fields: registration
func (_m *Manager) Update(registration *scanner.Registration) error {
ret := _m.Called(registration)
var r0 error
if rf, ok := ret.Get(0).(func(*scanner.Registration) error); ok {
r0 = rf(registration)
} else {
r0 = ret.Error(0)
}
return r0
}

View File

@ -260,7 +260,7 @@ func (s *scanAllAPI) requireScanEnabled(ctx context.Context) error {
Keywords: kws,
}
l, err := s.scannerCtl.ListRegistrations(query)
l, err := s.scannerCtl.ListRegistrations(ctx, query)
if err != nil {
return errors.Wrap(err, "check if scan is enabled")
}

View File

@ -20,20 +20,20 @@ type Controller struct {
mock.Mock
}
// CreateRegistration provides a mock function with given fields: registration
func (_m *Controller) CreateRegistration(registration *scanner.Registration) (string, error) {
ret := _m.Called(registration)
// CreateRegistration provides a mock function with given fields: ctx, registration
func (_m *Controller) CreateRegistration(ctx context.Context, registration *scanner.Registration) (string, error) {
ret := _m.Called(ctx, registration)
var r0 string
if rf, ok := ret.Get(0).(func(*scanner.Registration) string); ok {
r0 = rf(registration)
if rf, ok := ret.Get(0).(func(context.Context, *scanner.Registration) string); ok {
r0 = rf(ctx, registration)
} else {
r0 = ret.Get(0).(string)
}
var r1 error
if rf, ok := ret.Get(1).(func(*scanner.Registration) error); ok {
r1 = rf(registration)
if rf, ok := ret.Get(1).(func(context.Context, *scanner.Registration) error); ok {
r1 = rf(ctx, registration)
} else {
r1 = ret.Error(1)
}
@ -41,13 +41,13 @@ func (_m *Controller) CreateRegistration(registration *scanner.Registration) (st
return r0, r1
}
// DeleteRegistration provides a mock function with given fields: registrationUUID
func (_m *Controller) DeleteRegistration(registrationUUID string) (*scanner.Registration, error) {
ret := _m.Called(registrationUUID)
// DeleteRegistration provides a mock function with given fields: ctx, registrationUUID
func (_m *Controller) DeleteRegistration(ctx context.Context, registrationUUID string) (*scanner.Registration, error) {
ret := _m.Called(ctx, registrationUUID)
var r0 *scanner.Registration
if rf, ok := ret.Get(0).(func(string) *scanner.Registration); ok {
r0 = rf(registrationUUID)
if rf, ok := ret.Get(0).(func(context.Context, string) *scanner.Registration); ok {
r0 = rf(ctx, registrationUUID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*scanner.Registration)
@ -55,8 +55,8 @@ func (_m *Controller) DeleteRegistration(registrationUUID string) (*scanner.Regi
}
var r1 error
if rf, ok := ret.Get(1).(func(string) error); ok {
r1 = rf(registrationUUID)
if rf, ok := ret.Get(1).(func(context.Context, string) error); ok {
r1 = rf(ctx, registrationUUID)
} else {
r1 = ret.Error(1)
}
@ -64,13 +64,13 @@ func (_m *Controller) DeleteRegistration(registrationUUID string) (*scanner.Regi
return r0, r1
}
// GetMetadata provides a mock function with given fields: registrationUUID
func (_m *Controller) GetMetadata(registrationUUID string) (*v1.ScannerAdapterMetadata, error) {
ret := _m.Called(registrationUUID)
// GetMetadata provides a mock function with given fields: ctx, registrationUUID
func (_m *Controller) GetMetadata(ctx context.Context, registrationUUID string) (*v1.ScannerAdapterMetadata, error) {
ret := _m.Called(ctx, registrationUUID)
var r0 *v1.ScannerAdapterMetadata
if rf, ok := ret.Get(0).(func(string) *v1.ScannerAdapterMetadata); ok {
r0 = rf(registrationUUID)
if rf, ok := ret.Get(0).(func(context.Context, string) *v1.ScannerAdapterMetadata); ok {
r0 = rf(ctx, registrationUUID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*v1.ScannerAdapterMetadata)
@ -78,8 +78,8 @@ func (_m *Controller) GetMetadata(registrationUUID string) (*v1.ScannerAdapterMe
}
var r1 error
if rf, ok := ret.Get(1).(func(string) error); ok {
r1 = rf(registrationUUID)
if rf, ok := ret.Get(1).(func(context.Context, string) error); ok {
r1 = rf(ctx, registrationUUID)
} else {
r1 = ret.Error(1)
}
@ -87,13 +87,13 @@ func (_m *Controller) GetMetadata(registrationUUID string) (*v1.ScannerAdapterMe
return r0, r1
}
// GetRegistration provides a mock function with given fields: registrationUUID
func (_m *Controller) GetRegistration(registrationUUID string) (*scanner.Registration, error) {
ret := _m.Called(registrationUUID)
// GetRegistration provides a mock function with given fields: ctx, registrationUUID
func (_m *Controller) GetRegistration(ctx context.Context, registrationUUID string) (*scanner.Registration, error) {
ret := _m.Called(ctx, registrationUUID)
var r0 *scanner.Registration
if rf, ok := ret.Get(0).(func(string) *scanner.Registration); ok {
r0 = rf(registrationUUID)
if rf, ok := ret.Get(0).(func(context.Context, string) *scanner.Registration); ok {
r0 = rf(ctx, registrationUUID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*scanner.Registration)
@ -101,8 +101,8 @@ func (_m *Controller) GetRegistration(registrationUUID string) (*scanner.Registr
}
var r1 error
if rf, ok := ret.Get(1).(func(string) error); ok {
r1 = rf(registrationUUID)
if rf, ok := ret.Get(1).(func(context.Context, string) error); ok {
r1 = rf(ctx, registrationUUID)
} else {
r1 = ret.Error(1)
}
@ -140,13 +140,13 @@ func (_m *Controller) GetRegistrationByProject(ctx context.Context, projectID in
return r0, r1
}
// ListRegistrations provides a mock function with given fields: query
func (_m *Controller) ListRegistrations(query *q.Query) ([]*scanner.Registration, error) {
ret := _m.Called(query)
// ListRegistrations provides a mock function with given fields: ctx, query
func (_m *Controller) ListRegistrations(ctx context.Context, query *q.Query) ([]*scanner.Registration, error) {
ret := _m.Called(ctx, query)
var r0 []*scanner.Registration
if rf, ok := ret.Get(0).(func(*q.Query) []*scanner.Registration); ok {
r0 = rf(query)
if rf, ok := ret.Get(0).(func(context.Context, *q.Query) []*scanner.Registration); ok {
r0 = rf(ctx, query)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*scanner.Registration)
@ -154,8 +154,8 @@ func (_m *Controller) ListRegistrations(query *q.Query) ([]*scanner.Registration
}
var r1 error
if rf, ok := ret.Get(1).(func(*q.Query) error); ok {
r1 = rf(query)
if rf, ok := ret.Get(1).(func(context.Context, *q.Query) error); ok {
r1 = rf(ctx, query)
} else {
r1 = ret.Error(1)
}
@ -163,13 +163,13 @@ func (_m *Controller) ListRegistrations(query *q.Query) ([]*scanner.Registration
return r0, r1
}
// Ping provides a mock function with given fields: registration
func (_m *Controller) Ping(registration *scanner.Registration) (*v1.ScannerAdapterMetadata, error) {
ret := _m.Called(registration)
// Ping provides a mock function with given fields: ctx, registration
func (_m *Controller) Ping(ctx context.Context, registration *scanner.Registration) (*v1.ScannerAdapterMetadata, error) {
ret := _m.Called(ctx, registration)
var r0 *v1.ScannerAdapterMetadata
if rf, ok := ret.Get(0).(func(*scanner.Registration) *v1.ScannerAdapterMetadata); ok {
r0 = rf(registration)
if rf, ok := ret.Get(0).(func(context.Context, *scanner.Registration) *v1.ScannerAdapterMetadata); ok {
r0 = rf(ctx, registration)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*v1.ScannerAdapterMetadata)
@ -177,8 +177,8 @@ func (_m *Controller) Ping(registration *scanner.Registration) (*v1.ScannerAdapt
}
var r1 error
if rf, ok := ret.Get(1).(func(*scanner.Registration) error); ok {
r1 = rf(registration)
if rf, ok := ret.Get(1).(func(context.Context, *scanner.Registration) error); ok {
r1 = rf(ctx, registration)
} else {
r1 = ret.Error(1)
}
@ -186,13 +186,13 @@ func (_m *Controller) Ping(registration *scanner.Registration) (*v1.ScannerAdapt
return r0, r1
}
// RegistrationExists provides a mock function with given fields: registrationUUID
func (_m *Controller) RegistrationExists(registrationUUID string) bool {
ret := _m.Called(registrationUUID)
// RegistrationExists provides a mock function with given fields: ctx, registrationUUID
func (_m *Controller) RegistrationExists(ctx context.Context, registrationUUID string) bool {
ret := _m.Called(ctx, registrationUUID)
var r0 bool
if rf, ok := ret.Get(0).(func(string) bool); ok {
r0 = rf(registrationUUID)
if rf, ok := ret.Get(0).(func(context.Context, string) bool); ok {
r0 = rf(ctx, registrationUUID)
} else {
r0 = ret.Get(0).(bool)
}
@ -200,13 +200,13 @@ func (_m *Controller) RegistrationExists(registrationUUID string) bool {
return r0
}
// SetDefaultRegistration provides a mock function with given fields: registrationUUID
func (_m *Controller) SetDefaultRegistration(registrationUUID string) error {
ret := _m.Called(registrationUUID)
// SetDefaultRegistration provides a mock function with given fields: ctx, registrationUUID
func (_m *Controller) SetDefaultRegistration(ctx context.Context, registrationUUID string) error {
ret := _m.Called(ctx, registrationUUID)
var r0 error
if rf, ok := ret.Get(0).(func(string) error); ok {
r0 = rf(registrationUUID)
if rf, ok := ret.Get(0).(func(context.Context, string) error); ok {
r0 = rf(ctx, registrationUUID)
} else {
r0 = ret.Error(0)
}
@ -228,13 +228,13 @@ func (_m *Controller) SetRegistrationByProject(ctx context.Context, projectID in
return r0
}
// UpdateRegistration provides a mock function with given fields: registration
func (_m *Controller) UpdateRegistration(registration *scanner.Registration) error {
ret := _m.Called(registration)
// UpdateRegistration provides a mock function with given fields: ctx, registration
func (_m *Controller) UpdateRegistration(ctx context.Context, registration *scanner.Registration) error {
ret := _m.Called(ctx, registration)
var r0 error
if rf, ok := ret.Get(0).(func(*scanner.Registration) error); ok {
r0 = rf(registration)
if rf, ok := ret.Get(0).(func(context.Context, *scanner.Registration) error); ok {
r0 = rf(ctx, registration)
} else {
r0 = ret.Error(0)
}

View File

@ -3,6 +3,8 @@
package scanner
import (
context "context"
q "github.com/goharbor/harbor/src/lib/q"
mock "github.com/stretchr/testify/mock"
@ -14,20 +16,20 @@ type Manager struct {
mock.Mock
}
// Create provides a mock function with given fields: registration
func (_m *Manager) Create(registration *scanner.Registration) (string, error) {
ret := _m.Called(registration)
// Create provides a mock function with given fields: ctx, registration
func (_m *Manager) Create(ctx context.Context, registration *scanner.Registration) (string, error) {
ret := _m.Called(ctx, registration)
var r0 string
if rf, ok := ret.Get(0).(func(*scanner.Registration) string); ok {
r0 = rf(registration)
if rf, ok := ret.Get(0).(func(context.Context, *scanner.Registration) string); ok {
r0 = rf(ctx, registration)
} else {
r0 = ret.Get(0).(string)
}
var r1 error
if rf, ok := ret.Get(1).(func(*scanner.Registration) error); ok {
r1 = rf(registration)
if rf, ok := ret.Get(1).(func(context.Context, *scanner.Registration) error); ok {
r1 = rf(ctx, registration)
} else {
r1 = ret.Error(1)
}
@ -35,13 +37,13 @@ func (_m *Manager) Create(registration *scanner.Registration) (string, error) {
return r0, r1
}
// Delete provides a mock function with given fields: registrationUUID
func (_m *Manager) Delete(registrationUUID string) error {
ret := _m.Called(registrationUUID)
// Delete provides a mock function with given fields: ctx, registrationUUID
func (_m *Manager) Delete(ctx context.Context, registrationUUID string) error {
ret := _m.Called(ctx, registrationUUID)
var r0 error
if rf, ok := ret.Get(0).(func(string) error); ok {
r0 = rf(registrationUUID)
if rf, ok := ret.Get(0).(func(context.Context, string) error); ok {
r0 = rf(ctx, registrationUUID)
} else {
r0 = ret.Error(0)
}
@ -49,13 +51,13 @@ func (_m *Manager) Delete(registrationUUID string) error {
return r0
}
// Get provides a mock function with given fields: registrationUUID
func (_m *Manager) Get(registrationUUID string) (*scanner.Registration, error) {
ret := _m.Called(registrationUUID)
// Get provides a mock function with given fields: ctx, registrationUUID
func (_m *Manager) Get(ctx context.Context, registrationUUID string) (*scanner.Registration, error) {
ret := _m.Called(ctx, registrationUUID)
var r0 *scanner.Registration
if rf, ok := ret.Get(0).(func(string) *scanner.Registration); ok {
r0 = rf(registrationUUID)
if rf, ok := ret.Get(0).(func(context.Context, string) *scanner.Registration); ok {
r0 = rf(ctx, registrationUUID)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*scanner.Registration)
@ -63,8 +65,8 @@ func (_m *Manager) Get(registrationUUID string) (*scanner.Registration, error) {
}
var r1 error
if rf, ok := ret.Get(1).(func(string) error); ok {
r1 = rf(registrationUUID)
if rf, ok := ret.Get(1).(func(context.Context, string) error); ok {
r1 = rf(ctx, registrationUUID)
} else {
r1 = ret.Error(1)
}
@ -72,13 +74,13 @@ func (_m *Manager) Get(registrationUUID string) (*scanner.Registration, error) {
return r0, r1
}
// GetDefault provides a mock function with given fields:
func (_m *Manager) GetDefault() (*scanner.Registration, error) {
ret := _m.Called()
// GetDefault provides a mock function with given fields: ctx
func (_m *Manager) GetDefault(ctx context.Context) (*scanner.Registration, error) {
ret := _m.Called(ctx)
var r0 *scanner.Registration
if rf, ok := ret.Get(0).(func() *scanner.Registration); ok {
r0 = rf()
if rf, ok := ret.Get(0).(func(context.Context) *scanner.Registration); ok {
r0 = rf(ctx)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*scanner.Registration)
@ -86,8 +88,8 @@ func (_m *Manager) GetDefault() (*scanner.Registration, error) {
}
var r1 error
if rf, ok := ret.Get(1).(func() error); ok {
r1 = rf()
if rf, ok := ret.Get(1).(func(context.Context) error); ok {
r1 = rf(ctx)
} else {
r1 = ret.Error(1)
}
@ -95,13 +97,13 @@ func (_m *Manager) GetDefault() (*scanner.Registration, error) {
return r0, r1
}
// List provides a mock function with given fields: query
func (_m *Manager) List(query *q.Query) ([]*scanner.Registration, error) {
ret := _m.Called(query)
// List provides a mock function with given fields: ctx, query
func (_m *Manager) List(ctx context.Context, query *q.Query) ([]*scanner.Registration, error) {
ret := _m.Called(ctx, query)
var r0 []*scanner.Registration
if rf, ok := ret.Get(0).(func(*q.Query) []*scanner.Registration); ok {
r0 = rf(query)
if rf, ok := ret.Get(0).(func(context.Context, *q.Query) []*scanner.Registration); ok {
r0 = rf(ctx, query)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*scanner.Registration)
@ -109,8 +111,8 @@ func (_m *Manager) List(query *q.Query) ([]*scanner.Registration, error) {
}
var r1 error
if rf, ok := ret.Get(1).(func(*q.Query) error); ok {
r1 = rf(query)
if rf, ok := ret.Get(1).(func(context.Context, *q.Query) error); ok {
r1 = rf(ctx, query)
} else {
r1 = ret.Error(1)
}
@ -118,13 +120,13 @@ func (_m *Manager) List(query *q.Query) ([]*scanner.Registration, error) {
return r0, r1
}
// SetAsDefault provides a mock function with given fields: registrationUUID
func (_m *Manager) SetAsDefault(registrationUUID string) error {
ret := _m.Called(registrationUUID)
// SetAsDefault provides a mock function with given fields: ctx, registrationUUID
func (_m *Manager) SetAsDefault(ctx context.Context, registrationUUID string) error {
ret := _m.Called(ctx, registrationUUID)
var r0 error
if rf, ok := ret.Get(0).(func(string) error); ok {
r0 = rf(registrationUUID)
if rf, ok := ret.Get(0).(func(context.Context, string) error); ok {
r0 = rf(ctx, registrationUUID)
} else {
r0 = ret.Error(0)
}
@ -132,13 +134,13 @@ func (_m *Manager) SetAsDefault(registrationUUID string) error {
return r0
}
// Update provides a mock function with given fields: registration
func (_m *Manager) Update(registration *scanner.Registration) error {
ret := _m.Called(registration)
// Update provides a mock function with given fields: ctx, registration
func (_m *Manager) Update(ctx context.Context, registration *scanner.Registration) error {
ret := _m.Called(ctx, registration)
var r0 error
if rf, ok := ret.Get(0).(func(*scanner.Registration) error); ok {
r0 = rf(registration)
if rf, ok := ret.Get(0).(func(context.Context, *scanner.Registration) error); ok {
r0 = rf(ctx, registration)
} else {
r0 = ret.Error(0)
}