From ace21240a46ae7c7038f1706e930af036228fc7f Mon Sep 17 00:00:00 2001 From: chlins Date: Fri, 3 Jul 2020 11:59:32 +0800 Subject: [PATCH] fix: add count method of policy manager to replace list method return wrong counts Signed-off-by: chlins --- src/controller/p2p/preheat/enforcer.go | 2 +- src/controller/p2p/preheat/enforcer_test.go | 2 +- src/pkg/p2p/preheat/dao/policy/dao.go | 30 ++++-- src/pkg/p2p/preheat/dao/policy/dao_test.go | 13 ++- src/pkg/p2p/preheat/policy/manager.go | 15 ++- src/pkg/p2p/preheat/policy/manager_test.go | 19 +++- src/testing/pkg/p2p/preheat/policy/manager.go | 95 ++++++++++--------- 7 files changed, 110 insertions(+), 66 deletions(-) diff --git a/src/controller/p2p/preheat/enforcer.go b/src/controller/p2p/preheat/enforcer.go index c9652804b..ed1d10715 100644 --- a/src/controller/p2p/preheat/enforcer.go +++ b/src/controller/p2p/preheat/enforcer.go @@ -217,7 +217,7 @@ func (de *defaultEnforcer) PreheatArtifact(ctx context.Context, art *artifact.Ar } // Find all the policies that match the given artifact - _, l, err := de.policyMgr.ListPoliciesByProject(ctx, art.ProjectID, nil) + l, err := de.policyMgr.ListPoliciesByProject(ctx, art.ProjectID, nil) if err != nil { return nil, enforceErrorExt(err, art) } diff --git a/src/controller/p2p/preheat/enforcer_test.go b/src/controller/p2p/preheat/enforcer_test.go index 4ade55e3e..f6ee466ef 100644 --- a/src/controller/p2p/preheat/enforcer_test.go +++ b/src/controller/p2p/preheat/enforcer_test.go @@ -63,7 +63,7 @@ func (suite *EnforcerTestSuite) SetupSuite() { context.TODO(), mock.AnythingOfType("int64"), mock.AnythingOfType("*q.Query"), - ).Return((int64)(2), fakePolicies, nil) + ).Return(fakePolicies, nil) fakeExecManager := &task.FakeExecutionManager{} fakeExecManager.On("Create", diff --git a/src/pkg/p2p/preheat/dao/policy/dao.go b/src/pkg/p2p/preheat/dao/policy/dao.go index cf3a13bb2..99fbe8b50 100644 --- a/src/pkg/p2p/preheat/dao/policy/dao.go +++ b/src/pkg/p2p/preheat/dao/policy/dao.go @@ -26,6 +26,8 @@ import ( // DAO is the data access object for policy. type DAO interface { + // Count returns the total count of policies according to the query + Count(ctx context.Context, query *q.Query) (total int64, err error) // Create the policy schema Create(ctx context.Context, schema *policy.Schema) (id int64, err error) // Update the policy schema, Only the properties specified by "props" will be updated if it is set @@ -35,7 +37,7 @@ type DAO interface { // Delete the policy schema by id Delete(ctx context.Context, id int64) (err error) // List policy schemas by query - List(ctx context.Context, query *q.Query) (total int64, schemas []*policy.Schema, err error) + List(ctx context.Context, query *q.Query) (schemas []*policy.Schema, err error) } // New returns an instance of the default DAO. @@ -45,6 +47,23 @@ func New() DAO { type dao struct{} +// Count returns the total count of policies according to the query +func (d *dao) Count(ctx context.Context, query *q.Query) (total int64, err error) { + if query != nil { + // ignore the page number and size + query = &q.Query{ + Keywords: query.Keywords, + } + } + + qs, err := orm.QuerySetter(ctx, &policy.Schema{}, query) + if err != nil { + return 0, err + } + + return qs.Count() +} + // Create a policy schema. func (d *dao) Create(ctx context.Context, schema *policy.Schema) (id int64, err error) { var ormer beego_orm.Ormer @@ -126,22 +145,17 @@ func (d *dao) Delete(ctx context.Context, id int64) (err error) { } // List policies by query. -func (d *dao) List(ctx context.Context, query *q.Query) (total int64, schemas []*policy.Schema, err error) { +func (d *dao) List(ctx context.Context, query *q.Query) (schemas []*policy.Schema, err error) { var qs beego_orm.QuerySeter qs, err = orm.QuerySetter(ctx, &policy.Schema{}, query) if err != nil { return } - total, err = qs.Count() - if err != nil { - return - } - qs = qs.OrderBy("UpdatedTime", "ID") if _, err = qs.All(&schemas); err != nil { return } - return total, schemas, nil + return schemas, nil } diff --git a/src/pkg/p2p/preheat/dao/policy/dao_test.go b/src/pkg/p2p/preheat/dao/policy/dao_test.go index cfbf65c2f..1c3ba84a8 100644 --- a/src/pkg/p2p/preheat/dao/policy/dao_test.go +++ b/src/pkg/p2p/preheat/dao/policy/dao_test.go @@ -71,6 +71,13 @@ func (d *daoTestSuite) TearDownSuite() { d.Require().Nil(err) } +// TestCount tests count total +func (d *daoTestSuite) TestCount() { + total, err := d.dao.Count(d.ctx, nil) + d.Require().Nil(err) + d.Equal(int64(1), total) +} + // TestCreate tests create a policy schema. func (d *daoTestSuite) TestCreate() { // create duplicate policy should return error @@ -139,9 +146,8 @@ func (d *daoTestSuite) TestList() { d.Require().Nil(err) }() - total, policies, err := d.dao.List(d.ctx, &q.Query{}) + policies, err := d.dao.List(d.ctx, &q.Query{}) d.Require().Nil(err) - d.Equal(int64(2), total) d.Len(policies, 2, "list all policy schemas") // list policy filter by project @@ -150,9 +156,8 @@ func (d *daoTestSuite) TestList() { "project_id": 1, }, } - total, policies, err = d.dao.List(d.ctx, query) + policies, err = d.dao.List(d.ctx, query) d.Require().Nil(err) - d.Equal(int64(1), total) d.Len(policies, 1, "list policy schemas by project") d.Equal(d.defaultPolicy.Name, policies[0].Name) } diff --git a/src/pkg/p2p/preheat/policy/manager.go b/src/pkg/p2p/preheat/policy/manager.go index e9754eea9..1f6170a3f 100644 --- a/src/pkg/p2p/preheat/policy/manager.go +++ b/src/pkg/p2p/preheat/policy/manager.go @@ -27,6 +27,8 @@ var Mgr = New() // Manager manages the policy type Manager interface { + // Count returns the total count of policies according to the query + Count(ctx context.Context, query *q.Query) (total int64, err error) // Create the policy schema Create(ctx context.Context, schema *policy.Schema) (id int64, err error) // Update the policy schema, Only the properties specified by "props" will be updated if it is set @@ -36,9 +38,9 @@ type Manager interface { // Delete the policy schema by id Delete(ctx context.Context, id int64) (err error) // List policy schemas by query - ListPolicies(ctx context.Context, query *q.Query) (total int64, schemas []*policy.Schema, err error) + ListPolicies(ctx context.Context, query *q.Query) (schemas []*policy.Schema, err error) // list policy schema under project - ListPoliciesByProject(ctx context.Context, project int64, query *q.Query) (total int64, schemas []*policy.Schema, err error) + ListPoliciesByProject(ctx context.Context, project int64, query *q.Query) (schemas []*policy.Schema, err error) } type manager struct { @@ -52,6 +54,11 @@ func New() Manager { } } +// Count returns the total count of policies according to the query +func (m *manager) Count(ctx context.Context, query *q.Query) (total int64, err error) { + return m.dao.Count(ctx, query) +} + // Create the policy schema func (m *manager) Create(ctx context.Context, schema *policy.Schema) (id int64, err error) { return m.dao.Create(ctx, schema) @@ -73,12 +80,12 @@ func (m *manager) Delete(ctx context.Context, id int64) (err error) { } // List policy schemas by query -func (m *manager) ListPolicies(ctx context.Context, query *q.Query) (total int64, schemas []*policy.Schema, err error) { +func (m *manager) ListPolicies(ctx context.Context, query *q.Query) (schemas []*policy.Schema, err error) { return m.dao.List(ctx, query) } // list policy schema under project -func (m *manager) ListPoliciesByProject(ctx context.Context, project int64, query *q.Query) (total int64, schemas []*policy.Schema, err error) { +func (m *manager) ListPoliciesByProject(ctx context.Context, project int64, query *q.Query) (schemas []*policy.Schema, err error) { if query == nil { query = &q.Query{} } diff --git a/src/pkg/p2p/preheat/policy/manager_test.go b/src/pkg/p2p/preheat/policy/manager_test.go index 056e27122..a89915101 100644 --- a/src/pkg/p2p/preheat/policy/manager_test.go +++ b/src/pkg/p2p/preheat/policy/manager_test.go @@ -28,6 +28,10 @@ type fakeDao struct { mock.Mock } +func (f *fakeDao) Count(ctx context.Context, q *q.Query) (int64, error) { + args := f.Called() + return int64(args.Int(0)), args.Error(1) +} func (f *fakeDao) Create(ctx context.Context, schema *policy.Schema) (int64, error) { args := f.Called() return int64(args.Int(0)), args.Error(1) @@ -48,13 +52,13 @@ func (f *fakeDao) Delete(ctx context.Context, id int64) error { args := f.Called() return args.Error(0) } -func (f *fakeDao) List(ctx context.Context, query *q.Query) (int64, []*policy.Schema, error) { +func (f *fakeDao) List(ctx context.Context, query *q.Query) ([]*policy.Schema, error) { args := f.Called() var schemas []*policy.Schema if args.Get(0) != nil { schemas = args.Get(0).([]*policy.Schema) } - return 0, schemas, args.Error(1) + return schemas, args.Error(1) } type managerTestSuite struct { @@ -80,6 +84,13 @@ func (m *managerTestSuite) TearDownSuite() { m.mgr = nil } +// TestCount tests Count method. +func (m *managerTestSuite) TestCount() { + m.dao.On("Count").Return(1, nil) + _, err := m.mgr.Count(nil, nil) + m.Require().Nil(err) +} + // TestCreate tests Create method. func (m *managerTestSuite) TestCreate() { m.dao.On("Create").Return(1, nil) @@ -111,13 +122,13 @@ func (m *managerTestSuite) TestDelete() { // TestListPolicies tests ListPolicies method. func (m *managerTestSuite) TestListPolicies() { m.dao.On("List").Return(nil, nil) - _, _, err := m.mgr.ListPolicies(nil, nil) + _, err := m.mgr.ListPolicies(nil, nil) m.Require().Nil(err) } // TestListPoliciesByProject tests ListPoliciesByProject method. func (m *managerTestSuite) TestListPoliciesByProject() { m.dao.On("List").Return(nil, nil) - _, _, err := m.mgr.ListPoliciesByProject(nil, 1, nil) + _, err := m.mgr.ListPoliciesByProject(nil, 1, nil) m.Require().Nil(err) } diff --git a/src/testing/pkg/p2p/preheat/policy/manager.go b/src/testing/pkg/p2p/preheat/policy/manager.go index 672f17a75..e9d01f05c 100644 --- a/src/testing/pkg/p2p/preheat/policy/manager.go +++ b/src/testing/pkg/p2p/preheat/policy/manager.go @@ -5,7 +5,7 @@ package policy import ( context "context" - policy "github.com/goharbor/harbor/src/pkg/p2p/preheat/models/policy" + modelspolicy "github.com/goharbor/harbor/src/pkg/p2p/preheat/models/policy" mock "github.com/stretchr/testify/mock" q "github.com/goharbor/harbor/src/lib/q" @@ -16,19 +16,40 @@ type FakeManager struct { mock.Mock } +// Count provides a mock function with given fields: ctx, query +func (_m *FakeManager) Count(ctx context.Context, query *q.Query) (int64, error) { + ret := _m.Called(ctx, query) + + var r0 int64 + if rf, ok := ret.Get(0).(func(context.Context, *q.Query) int64); ok { + r0 = rf(ctx, query) + } else { + r0 = ret.Get(0).(int64) + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *q.Query) error); ok { + r1 = rf(ctx, query) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + // Create provides a mock function with given fields: ctx, schema -func (_m *FakeManager) Create(ctx context.Context, schema *policy.Schema) (int64, error) { +func (_m *FakeManager) Create(ctx context.Context, schema *modelspolicy.Schema) (int64, error) { ret := _m.Called(ctx, schema) var r0 int64 - if rf, ok := ret.Get(0).(func(context.Context, *policy.Schema) int64); ok { + if rf, ok := ret.Get(0).(func(context.Context, *modelspolicy.Schema) int64); ok { r0 = rf(ctx, schema) } else { r0 = ret.Get(0).(int64) } var r1 error - if rf, ok := ret.Get(1).(func(context.Context, *policy.Schema) error); ok { + if rf, ok := ret.Get(1).(func(context.Context, *modelspolicy.Schema) error); ok { r1 = rf(ctx, schema) } else { r1 = ret.Error(1) @@ -52,15 +73,15 @@ func (_m *FakeManager) Delete(ctx context.Context, id int64) error { } // Get provides a mock function with given fields: ctx, id -func (_m *FakeManager) Get(ctx context.Context, id int64) (*policy.Schema, error) { +func (_m *FakeManager) Get(ctx context.Context, id int64) (*modelspolicy.Schema, error) { ret := _m.Called(ctx, id) - var r0 *policy.Schema - if rf, ok := ret.Get(0).(func(context.Context, int64) *policy.Schema); ok { + var r0 *modelspolicy.Schema + if rf, ok := ret.Get(0).(func(context.Context, int64) *modelspolicy.Schema); ok { r0 = rf(ctx, id) } else { if ret.Get(0) != nil { - r0 = ret.Get(0).(*policy.Schema) + r0 = ret.Get(0).(*modelspolicy.Schema) } } @@ -75,67 +96,53 @@ func (_m *FakeManager) Get(ctx context.Context, id int64) (*policy.Schema, error } // ListPolicies provides a mock function with given fields: ctx, query -func (_m *FakeManager) ListPolicies(ctx context.Context, query *q.Query) (int64, []*policy.Schema, error) { +func (_m *FakeManager) ListPolicies(ctx context.Context, query *q.Query) ([]*modelspolicy.Schema, error) { ret := _m.Called(ctx, query) - var r0 int64 - if rf, ok := ret.Get(0).(func(context.Context, *q.Query) int64); ok { + var r0 []*modelspolicy.Schema + if rf, ok := ret.Get(0).(func(context.Context, *q.Query) []*modelspolicy.Schema); ok { r0 = rf(ctx, query) } else { - r0 = ret.Get(0).(int64) - } - - var r1 []*policy.Schema - if rf, ok := ret.Get(1).(func(context.Context, *q.Query) []*policy.Schema); ok { - r1 = rf(ctx, query) - } else { - if ret.Get(1) != nil { - r1 = ret.Get(1).([]*policy.Schema) + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*modelspolicy.Schema) } } - var r2 error - if rf, ok := ret.Get(2).(func(context.Context, *q.Query) error); ok { - r2 = rf(ctx, query) + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *q.Query) error); ok { + r1 = rf(ctx, query) } else { - r2 = ret.Error(2) + r1 = ret.Error(1) } - return r0, r1, r2 + return r0, r1 } // ListPoliciesByProject provides a mock function with given fields: ctx, project, query -func (_m *FakeManager) ListPoliciesByProject(ctx context.Context, project int64, query *q.Query) (int64, []*policy.Schema, error) { +func (_m *FakeManager) ListPoliciesByProject(ctx context.Context, project int64, query *q.Query) ([]*modelspolicy.Schema, error) { ret := _m.Called(ctx, project, query) - var r0 int64 - if rf, ok := ret.Get(0).(func(context.Context, int64, *q.Query) int64); ok { + var r0 []*modelspolicy.Schema + if rf, ok := ret.Get(0).(func(context.Context, int64, *q.Query) []*modelspolicy.Schema); ok { r0 = rf(ctx, project, query) } else { - r0 = ret.Get(0).(int64) - } - - var r1 []*policy.Schema - if rf, ok := ret.Get(1).(func(context.Context, int64, *q.Query) []*policy.Schema); ok { - r1 = rf(ctx, project, query) - } else { - if ret.Get(1) != nil { - r1 = ret.Get(1).([]*policy.Schema) + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*modelspolicy.Schema) } } - var r2 error - if rf, ok := ret.Get(2).(func(context.Context, int64, *q.Query) error); ok { - r2 = rf(ctx, project, query) + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, int64, *q.Query) error); ok { + r1 = rf(ctx, project, query) } else { - r2 = ret.Error(2) + r1 = ret.Error(1) } - return r0, r1, r2 + return r0, r1 } // Update provides a mock function with given fields: ctx, schema, props -func (_m *FakeManager) Update(ctx context.Context, schema *policy.Schema, props ...string) error { +func (_m *FakeManager) Update(ctx context.Context, schema *modelspolicy.Schema, props ...string) error { _va := make([]interface{}, len(props)) for _i := range props { _va[_i] = props[_i] @@ -146,7 +153,7 @@ func (_m *FakeManager) Update(ctx context.Context, schema *policy.Schema, props ret := _m.Called(_ca...) var r0 error - if rf, ok := ret.Get(0).(func(context.Context, *policy.Schema, ...string) error); ok { + if rf, ok := ret.Get(0).(func(context.Context, *modelspolicy.Schema, ...string) error); ok { r0 = rf(ctx, schema, props...) } else { r0 = ret.Error(0)