fix: wrap report vuls record creating in transaction (#14176)

Make the creating of the ReportVulnerabilityRecord in transaction to
avoid parallel problem

Closes #14171

Signed-off-by: He Weiwei <hweiwei@vmware.com>
This commit is contained in:
He Weiwei 2021-02-05 12:15:52 +08:00 committed by GitHub
parent de97b900cf
commit 44ba7de738
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 185 additions and 40 deletions

View File

@ -54,8 +54,7 @@ func AsNotFoundError(err error, messageFormat string, args ...interface{}) *erro
// AsConflictError checks whether the err is duplicate key error. If it it, wrap it
// as a src/internal/error.Error with conflict error code, else return nil
func AsConflictError(err error, messageFormat string, args ...interface{}) *errors.Error {
var pqErr *pq.Error
if errors.As(err, &pqErr) && pqErr.Code == "23505" {
if isDuplicateKeyError(err) {
e := errors.New(err).
WithCode(errors.ConflictCode).
WithMessage(messageFormat, args...)
@ -67,8 +66,7 @@ func AsConflictError(err error, messageFormat string, args ...interface{}) *erro
// AsForeignKeyError checks whether the err is violating foreign key constraint error. If it it, wrap it
// as a src/internal/error.Error with violating foreign key constraint error code, else return nil
func AsForeignKeyError(err error, messageFormat string, args ...interface{}) *errors.Error {
var pqErr *pq.Error
if errors.As(err, &pqErr) && pqErr.Code == "23503" {
if isViolatingForeignKeyConstraintError(err) {
e := errors.New(err).
WithCode(errors.ViolateForeignKeyConstraintCode).
WithMessage(messageFormat, args...)
@ -76,3 +74,21 @@ func AsForeignKeyError(err error, messageFormat string, args ...interface{}) *er
}
return nil
}
func isDuplicateKeyError(err error) bool {
var pqErr *pq.Error
if errors.As(err, &pqErr) && pqErr.Code == "23505" {
return true
}
return false
}
func isViolatingForeignKeyConstraintError(err error) bool {
var pqErr *pq.Error
if errors.As(err, &pqErr) && pqErr.Code == "23503" {
return true
}
return false
}

View File

@ -17,6 +17,7 @@ package orm
import (
"context"
"errors"
"fmt"
"github.com/astaxie/beego/orm"
"github.com/goharbor/harbor/src/lib/log"
@ -87,3 +88,61 @@ func WithTransaction(f func(ctx context.Context) error) func(ctx context.Context
return nil
}
}
// ReadOrCreate read or create instance to datebase, retry to read when met a duplicate key error after the creating
func ReadOrCreate(ctx context.Context, md interface{}, col1 string, cols ...string) (created bool, id int64, err error) {
getter, ok := md.(interface {
GetID() int64
})
if !ok {
err = fmt.Errorf("missing GetID method for the model %T", md)
return
}
defer func() {
if !created && err == nil { // found in the database
id = getter.GetID()
}
}()
o, err := FromContext(ctx)
if err != nil {
return
}
cols = append([]string{col1}, cols...)
err = o.Read(md, cols...)
if err == nil { // found in the database
return
}
if !errors.Is(err, orm.ErrNoRows) { // met a error when read database
return
}
// not found in the database, try to create one
err = WithTransaction(func(ctx context.Context) error {
o, err := FromContext(ctx)
if err != nil {
return err
}
id, err = o.Insert(md)
return err
})(ctx)
if err == nil { // create success
created = true
return
}
// got a duplicate key error, try to read again
if isDuplicateKeyError(err) {
err = o.Read(md, cols...)
}
return
}

View File

@ -17,6 +17,7 @@ package orm
import (
"context"
"errors"
"sync"
"testing"
"github.com/astaxie/beego/orm"
@ -29,10 +30,14 @@ type Foo struct {
Name string `orm:"column(name)"`
}
func (*Foo) TableName() string {
func (foo *Foo) TableName() string {
return "foo"
}
func (foo *Foo) GetID() int64 {
return foo.ID
}
func addFoo(ctx context.Context, foo Foo) (int64, error) {
o, err := FromContext(ctx)
if err != nil {
@ -349,6 +354,61 @@ func (suite *OrmSuite) TestNestedSavepoint() {
suite.False(existFoo(ctx, id2))
}
func (suite *OrmSuite) TestReadOrCreate() {
ctx := NewContext(context.TODO(), orm.NewOrm())
var id int64
f1 := func(ctx context.Context) (err error) {
created1, id1, err := ReadOrCreate(ctx, &Foo{Name: "n1"}, "name")
suite.NoError(err)
suite.True(created1)
created2, id2, err := ReadOrCreate(ctx, &Foo{Name: "n1"}, "name")
suite.NoError(err)
suite.False(created2)
suite.Equal(id2, id1)
id = id1
return nil
}
suite.NoError(WithTransaction(f1)(ctx))
suite.True(existFoo(ctx, id))
}
func (suite *OrmSuite) TestReadOrCreateParallel() {
count := 500
arr := make([]int, count)
var wg sync.WaitGroup
for i := 0; i < count; i++ {
wg.Add(1)
go func(i int) {
defer wg.Done()
ctx := NewContext(context.TODO(), orm.NewOrm())
created, _, err := ReadOrCreate(ctx, &Foo{Name: "n2"}, "name")
suite.NoError(err)
if created {
arr[i] = 1
}
}(i)
}
wg.Wait()
sum := 0
for _, v := range arr {
sum += v
}
suite.Equal(1, sum)
}
func TestRunOrmSuite(t *testing.T) {
suite.Run(t, new(OrmSuite))
}

View File

@ -70,6 +70,23 @@ type VulnerabilityRecord struct {
VendorAttributes string `orm:"column(vendor_attributes);type(json);null"`
}
// TableName for VulnerabilityRecord
func (vr *VulnerabilityRecord) TableName() string {
return "vulnerability_record"
}
// TableUnique for VulnerabilityRecord
func (vr *VulnerabilityRecord) TableUnique() [][]string {
return [][]string{
{"cve_id", "registration_uuid", "package", "package_version"},
}
}
// GetID returns the ID of the record
func (vr *VulnerabilityRecord) GetID() int64 {
return vr.ID
}
// ReportVulnerabilityRecord is relation table required to optimize data storage for both the
// vulnerability records and the scan report.
// identified by composite key (ID, Report)
@ -83,18 +100,6 @@ type ReportVulnerabilityRecord struct {
VulnRecordID int64 `orm:"column(vuln_record_id);"`
}
// TableName for VulnerabilityRecord
func (vr *VulnerabilityRecord) TableName() string {
return "vulnerability_record"
}
// TableUnique for VulnerabilityRecord
func (vr *VulnerabilityRecord) TableUnique() [][]string {
return [][]string{
{"cve_id", "registration_uuid", "package", "package_version"},
}
}
// TableName for ReportVulnerabilityRecord
func (rvr *ReportVulnerabilityRecord) TableName() string {
return "report_vulnerability_record"
@ -106,3 +111,8 @@ func (rvr *ReportVulnerabilityRecord) TableUnique() [][]string {
{"report_uuid", "vuln_record_id"},
}
}
// GetID returns the ID of the record
func (rvr *ReportVulnerabilityRecord) GetID() int64 {
return rvr.ID
}

View File

@ -17,7 +17,7 @@ package scan
import (
"context"
"fmt"
"github.com/goharbor/harbor/src/lib/errors"
"github.com/goharbor/harbor/src/lib/orm"
"github.com/goharbor/harbor/src/lib/q"
)
@ -60,19 +60,8 @@ type vulnerabilityRecordDao struct{}
// Create creates new vulnerability record.
func (v *vulnerabilityRecordDao) Create(ctx context.Context, vr *VulnerabilityRecord) (int64, error) {
o, err := orm.FromContext(ctx)
var vrID int64
err = orm.WithTransaction(func(ctx context.Context) error {
var err error
vrID, err = o.InsertOrUpdate(vr, "cve_id, registration_uuid, package, package_version")
return orm.WrapConflictError(err, "vulnerability already exists")
})(ctx)
if errors.IsConflictErr(err) {
if err := o.Read(vr, "cve_id", "registration_uuid", "package", "package_version"); err != nil {
return 0, err
}
return vr.ID, nil
}
_, vrID, err := orm.ReadOrCreate(ctx, vr, "cve_id", "registration_uuid", "package", "package_version")
return vrID, err
}
@ -137,11 +126,7 @@ func (v *vulnerabilityRecordDao) InsertForReport(ctx context.Context, reportUUID
rvr.Report = reportUUID
rvr.VulnRecordID = vrID
o, err := orm.FromContext(ctx)
if err != nil {
return 0, err
}
_, rvrID, err := o.ReadOrCreate(rvr, "report_uuid", "vuln_record_id")
_, rvrID, err := orm.ReadOrCreate(ctx, rvr, "report_uuid", "vuln_record_id")
return rvrID, err
@ -164,8 +149,8 @@ func (v *vulnerabilityRecordDao) GetForReport(ctx context.Context, reportUUID st
if err != nil {
return nil, err
}
query := `select vulnerability_record.* from vulnerability_record
inner join report_vulnerability_record on
query := `select vulnerability_record.* from vulnerability_record
inner join report_vulnerability_record on
vulnerability_record.id = report_vulnerability_record.vuln_record_id and report_vulnerability_record.report_uuid=?`
_, err = o.Raw(query, reportUUID).QueryRows(&vulnRecs)
return vulnRecs, err

View File

@ -1,17 +1,32 @@
// Copyright Project Harbor Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package scan
import (
"fmt"
"testing"
"github.com/goharbor/harbor/src/jobservice/job"
"github.com/goharbor/harbor/src/lib/orm"
"github.com/goharbor/harbor/src/lib/q"
"github.com/goharbor/harbor/src/pkg/scan/dao/scanner"
"github.com/goharbor/harbor/src/pkg/scan/rest/v1"
v1 "github.com/goharbor/harbor/src/pkg/scan/rest/v1"
htesting "github.com/goharbor/harbor/src/testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/stretchr/testify/suite"
"testing"
)
const sampleReportWithCompleteVulnData = `{