fix: add count method of policy manager to replace list method return wrong counts

Signed-off-by: chlins <chlins.zhang@gmail.com>
This commit is contained in:
chlins 2020-07-03 11:59:32 +08:00
parent 47e731d885
commit ace21240a4
7 changed files with 110 additions and 66 deletions

View File

@ -217,7 +217,7 @@ func (de *defaultEnforcer) PreheatArtifact(ctx context.Context, art *artifact.Ar
}
// Find all the policies that match the given artifact
_, l, err := de.policyMgr.ListPoliciesByProject(ctx, art.ProjectID, nil)
l, err := de.policyMgr.ListPoliciesByProject(ctx, art.ProjectID, nil)
if err != nil {
return nil, enforceErrorExt(err, art)
}

View File

@ -63,7 +63,7 @@ func (suite *EnforcerTestSuite) SetupSuite() {
context.TODO(),
mock.AnythingOfType("int64"),
mock.AnythingOfType("*q.Query"),
).Return((int64)(2), fakePolicies, nil)
).Return(fakePolicies, nil)
fakeExecManager := &task.FakeExecutionManager{}
fakeExecManager.On("Create",

View File

@ -26,6 +26,8 @@ import (
// DAO is the data access object for policy.
type DAO interface {
// Count returns the total count of policies according to the query
Count(ctx context.Context, query *q.Query) (total int64, err error)
// Create the policy schema
Create(ctx context.Context, schema *policy.Schema) (id int64, err error)
// Update the policy schema, Only the properties specified by "props" will be updated if it is set
@ -35,7 +37,7 @@ type DAO interface {
// Delete the policy schema by id
Delete(ctx context.Context, id int64) (err error)
// List policy schemas by query
List(ctx context.Context, query *q.Query) (total int64, schemas []*policy.Schema, err error)
List(ctx context.Context, query *q.Query) (schemas []*policy.Schema, err error)
}
// New returns an instance of the default DAO.
@ -45,6 +47,23 @@ func New() DAO {
type dao struct{}
// Count returns the total count of policies according to the query
func (d *dao) Count(ctx context.Context, query *q.Query) (total int64, err error) {
if query != nil {
// ignore the page number and size
query = &q.Query{
Keywords: query.Keywords,
}
}
qs, err := orm.QuerySetter(ctx, &policy.Schema{}, query)
if err != nil {
return 0, err
}
return qs.Count()
}
// Create a policy schema.
func (d *dao) Create(ctx context.Context, schema *policy.Schema) (id int64, err error) {
var ormer beego_orm.Ormer
@ -126,22 +145,17 @@ func (d *dao) Delete(ctx context.Context, id int64) (err error) {
}
// List policies by query.
func (d *dao) List(ctx context.Context, query *q.Query) (total int64, schemas []*policy.Schema, err error) {
func (d *dao) List(ctx context.Context, query *q.Query) (schemas []*policy.Schema, err error) {
var qs beego_orm.QuerySeter
qs, err = orm.QuerySetter(ctx, &policy.Schema{}, query)
if err != nil {
return
}
total, err = qs.Count()
if err != nil {
return
}
qs = qs.OrderBy("UpdatedTime", "ID")
if _, err = qs.All(&schemas); err != nil {
return
}
return total, schemas, nil
return schemas, nil
}

View File

@ -71,6 +71,13 @@ func (d *daoTestSuite) TearDownSuite() {
d.Require().Nil(err)
}
// TestCount tests count total
func (d *daoTestSuite) TestCount() {
total, err := d.dao.Count(d.ctx, nil)
d.Require().Nil(err)
d.Equal(int64(1), total)
}
// TestCreate tests create a policy schema.
func (d *daoTestSuite) TestCreate() {
// create duplicate policy should return error
@ -139,9 +146,8 @@ func (d *daoTestSuite) TestList() {
d.Require().Nil(err)
}()
total, policies, err := d.dao.List(d.ctx, &q.Query{})
policies, err := d.dao.List(d.ctx, &q.Query{})
d.Require().Nil(err)
d.Equal(int64(2), total)
d.Len(policies, 2, "list all policy schemas")
// list policy filter by project
@ -150,9 +156,8 @@ func (d *daoTestSuite) TestList() {
"project_id": 1,
},
}
total, policies, err = d.dao.List(d.ctx, query)
policies, err = d.dao.List(d.ctx, query)
d.Require().Nil(err)
d.Equal(int64(1), total)
d.Len(policies, 1, "list policy schemas by project")
d.Equal(d.defaultPolicy.Name, policies[0].Name)
}

View File

@ -27,6 +27,8 @@ var Mgr = New()
// Manager manages the policy
type Manager interface {
// Count returns the total count of policies according to the query
Count(ctx context.Context, query *q.Query) (total int64, err error)
// Create the policy schema
Create(ctx context.Context, schema *policy.Schema) (id int64, err error)
// Update the policy schema, Only the properties specified by "props" will be updated if it is set
@ -36,9 +38,9 @@ type Manager interface {
// Delete the policy schema by id
Delete(ctx context.Context, id int64) (err error)
// List policy schemas by query
ListPolicies(ctx context.Context, query *q.Query) (total int64, schemas []*policy.Schema, err error)
ListPolicies(ctx context.Context, query *q.Query) (schemas []*policy.Schema, err error)
// list policy schema under project
ListPoliciesByProject(ctx context.Context, project int64, query *q.Query) (total int64, schemas []*policy.Schema, err error)
ListPoliciesByProject(ctx context.Context, project int64, query *q.Query) (schemas []*policy.Schema, err error)
}
type manager struct {
@ -52,6 +54,11 @@ func New() Manager {
}
}
// Count returns the total count of policies according to the query
func (m *manager) Count(ctx context.Context, query *q.Query) (total int64, err error) {
return m.dao.Count(ctx, query)
}
// Create the policy schema
func (m *manager) Create(ctx context.Context, schema *policy.Schema) (id int64, err error) {
return m.dao.Create(ctx, schema)
@ -73,12 +80,12 @@ func (m *manager) Delete(ctx context.Context, id int64) (err error) {
}
// List policy schemas by query
func (m *manager) ListPolicies(ctx context.Context, query *q.Query) (total int64, schemas []*policy.Schema, err error) {
func (m *manager) ListPolicies(ctx context.Context, query *q.Query) (schemas []*policy.Schema, err error) {
return m.dao.List(ctx, query)
}
// list policy schema under project
func (m *manager) ListPoliciesByProject(ctx context.Context, project int64, query *q.Query) (total int64, schemas []*policy.Schema, err error) {
func (m *manager) ListPoliciesByProject(ctx context.Context, project int64, query *q.Query) (schemas []*policy.Schema, err error) {
if query == nil {
query = &q.Query{}
}

View File

@ -28,6 +28,10 @@ type fakeDao struct {
mock.Mock
}
func (f *fakeDao) Count(ctx context.Context, q *q.Query) (int64, error) {
args := f.Called()
return int64(args.Int(0)), args.Error(1)
}
func (f *fakeDao) Create(ctx context.Context, schema *policy.Schema) (int64, error) {
args := f.Called()
return int64(args.Int(0)), args.Error(1)
@ -48,13 +52,13 @@ func (f *fakeDao) Delete(ctx context.Context, id int64) error {
args := f.Called()
return args.Error(0)
}
func (f *fakeDao) List(ctx context.Context, query *q.Query) (int64, []*policy.Schema, error) {
func (f *fakeDao) List(ctx context.Context, query *q.Query) ([]*policy.Schema, error) {
args := f.Called()
var schemas []*policy.Schema
if args.Get(0) != nil {
schemas = args.Get(0).([]*policy.Schema)
}
return 0, schemas, args.Error(1)
return schemas, args.Error(1)
}
type managerTestSuite struct {
@ -80,6 +84,13 @@ func (m *managerTestSuite) TearDownSuite() {
m.mgr = nil
}
// TestCount tests Count method.
func (m *managerTestSuite) TestCount() {
m.dao.On("Count").Return(1, nil)
_, err := m.mgr.Count(nil, nil)
m.Require().Nil(err)
}
// TestCreate tests Create method.
func (m *managerTestSuite) TestCreate() {
m.dao.On("Create").Return(1, nil)
@ -111,13 +122,13 @@ func (m *managerTestSuite) TestDelete() {
// TestListPolicies tests ListPolicies method.
func (m *managerTestSuite) TestListPolicies() {
m.dao.On("List").Return(nil, nil)
_, _, err := m.mgr.ListPolicies(nil, nil)
_, err := m.mgr.ListPolicies(nil, nil)
m.Require().Nil(err)
}
// TestListPoliciesByProject tests ListPoliciesByProject method.
func (m *managerTestSuite) TestListPoliciesByProject() {
m.dao.On("List").Return(nil, nil)
_, _, err := m.mgr.ListPoliciesByProject(nil, 1, nil)
_, err := m.mgr.ListPoliciesByProject(nil, 1, nil)
m.Require().Nil(err)
}

View File

@ -5,7 +5,7 @@ package policy
import (
context "context"
policy "github.com/goharbor/harbor/src/pkg/p2p/preheat/models/policy"
modelspolicy "github.com/goharbor/harbor/src/pkg/p2p/preheat/models/policy"
mock "github.com/stretchr/testify/mock"
q "github.com/goharbor/harbor/src/lib/q"
@ -16,19 +16,40 @@ type FakeManager struct {
mock.Mock
}
// Count provides a mock function with given fields: ctx, query
func (_m *FakeManager) Count(ctx context.Context, query *q.Query) (int64, error) {
ret := _m.Called(ctx, query)
var r0 int64
if rf, ok := ret.Get(0).(func(context.Context, *q.Query) int64); ok {
r0 = rf(ctx, query)
} else {
r0 = ret.Get(0).(int64)
}
var r1 error
if rf, ok := ret.Get(1).(func(context.Context, *q.Query) error); ok {
r1 = rf(ctx, query)
} else {
r1 = ret.Error(1)
}
return r0, r1
}
// Create provides a mock function with given fields: ctx, schema
func (_m *FakeManager) Create(ctx context.Context, schema *policy.Schema) (int64, error) {
func (_m *FakeManager) Create(ctx context.Context, schema *modelspolicy.Schema) (int64, error) {
ret := _m.Called(ctx, schema)
var r0 int64
if rf, ok := ret.Get(0).(func(context.Context, *policy.Schema) int64); ok {
if rf, ok := ret.Get(0).(func(context.Context, *modelspolicy.Schema) int64); ok {
r0 = rf(ctx, schema)
} else {
r0 = ret.Get(0).(int64)
}
var r1 error
if rf, ok := ret.Get(1).(func(context.Context, *policy.Schema) error); ok {
if rf, ok := ret.Get(1).(func(context.Context, *modelspolicy.Schema) error); ok {
r1 = rf(ctx, schema)
} else {
r1 = ret.Error(1)
@ -52,15 +73,15 @@ func (_m *FakeManager) Delete(ctx context.Context, id int64) error {
}
// Get provides a mock function with given fields: ctx, id
func (_m *FakeManager) Get(ctx context.Context, id int64) (*policy.Schema, error) {
func (_m *FakeManager) Get(ctx context.Context, id int64) (*modelspolicy.Schema, error) {
ret := _m.Called(ctx, id)
var r0 *policy.Schema
if rf, ok := ret.Get(0).(func(context.Context, int64) *policy.Schema); ok {
var r0 *modelspolicy.Schema
if rf, ok := ret.Get(0).(func(context.Context, int64) *modelspolicy.Schema); ok {
r0 = rf(ctx, id)
} else {
if ret.Get(0) != nil {
r0 = ret.Get(0).(*policy.Schema)
r0 = ret.Get(0).(*modelspolicy.Schema)
}
}
@ -75,67 +96,53 @@ func (_m *FakeManager) Get(ctx context.Context, id int64) (*policy.Schema, error
}
// ListPolicies provides a mock function with given fields: ctx, query
func (_m *FakeManager) ListPolicies(ctx context.Context, query *q.Query) (int64, []*policy.Schema, error) {
func (_m *FakeManager) ListPolicies(ctx context.Context, query *q.Query) ([]*modelspolicy.Schema, error) {
ret := _m.Called(ctx, query)
var r0 int64
if rf, ok := ret.Get(0).(func(context.Context, *q.Query) int64); ok {
var r0 []*modelspolicy.Schema
if rf, ok := ret.Get(0).(func(context.Context, *q.Query) []*modelspolicy.Schema); ok {
r0 = rf(ctx, query)
} else {
r0 = ret.Get(0).(int64)
}
var r1 []*policy.Schema
if rf, ok := ret.Get(1).(func(context.Context, *q.Query) []*policy.Schema); ok {
r1 = rf(ctx, query)
} else {
if ret.Get(1) != nil {
r1 = ret.Get(1).([]*policy.Schema)
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*modelspolicy.Schema)
}
}
var r2 error
if rf, ok := ret.Get(2).(func(context.Context, *q.Query) error); ok {
r2 = rf(ctx, query)
var r1 error
if rf, ok := ret.Get(1).(func(context.Context, *q.Query) error); ok {
r1 = rf(ctx, query)
} else {
r2 = ret.Error(2)
r1 = ret.Error(1)
}
return r0, r1, r2
return r0, r1
}
// ListPoliciesByProject provides a mock function with given fields: ctx, project, query
func (_m *FakeManager) ListPoliciesByProject(ctx context.Context, project int64, query *q.Query) (int64, []*policy.Schema, error) {
func (_m *FakeManager) ListPoliciesByProject(ctx context.Context, project int64, query *q.Query) ([]*modelspolicy.Schema, error) {
ret := _m.Called(ctx, project, query)
var r0 int64
if rf, ok := ret.Get(0).(func(context.Context, int64, *q.Query) int64); ok {
var r0 []*modelspolicy.Schema
if rf, ok := ret.Get(0).(func(context.Context, int64, *q.Query) []*modelspolicy.Schema); ok {
r0 = rf(ctx, project, query)
} else {
r0 = ret.Get(0).(int64)
}
var r1 []*policy.Schema
if rf, ok := ret.Get(1).(func(context.Context, int64, *q.Query) []*policy.Schema); ok {
r1 = rf(ctx, project, query)
} else {
if ret.Get(1) != nil {
r1 = ret.Get(1).([]*policy.Schema)
if ret.Get(0) != nil {
r0 = ret.Get(0).([]*modelspolicy.Schema)
}
}
var r2 error
if rf, ok := ret.Get(2).(func(context.Context, int64, *q.Query) error); ok {
r2 = rf(ctx, project, query)
var r1 error
if rf, ok := ret.Get(1).(func(context.Context, int64, *q.Query) error); ok {
r1 = rf(ctx, project, query)
} else {
r2 = ret.Error(2)
r1 = ret.Error(1)
}
return r0, r1, r2
return r0, r1
}
// Update provides a mock function with given fields: ctx, schema, props
func (_m *FakeManager) Update(ctx context.Context, schema *policy.Schema, props ...string) error {
func (_m *FakeManager) Update(ctx context.Context, schema *modelspolicy.Schema, props ...string) error {
_va := make([]interface{}, len(props))
for _i := range props {
_va[_i] = props[_i]
@ -146,7 +153,7 @@ func (_m *FakeManager) Update(ctx context.Context, schema *policy.Schema, props
ret := _m.Called(_ca...)
var r0 error
if rf, ok := ret.Get(0).(func(context.Context, *policy.Schema, ...string) error); ok {
if rf, ok := ret.Get(0).(func(context.Context, *modelspolicy.Schema, ...string) error); ok {
r0 = rf(ctx, schema, props...)
} else {
r0 = ret.Error(0)