mirror of
https://github.com/goharbor/harbor.git
synced 2025-02-16 20:01:35 +01:00
fix: wrap report vuls record creating in transaction (#14176)
Make the creating of the ReportVulnerabilityRecord in transaction to avoid parallel problem Closes #14171 Signed-off-by: He Weiwei <hweiwei@vmware.com>
This commit is contained in:
parent
de97b900cf
commit
44ba7de738
@ -54,8 +54,7 @@ func AsNotFoundError(err error, messageFormat string, args ...interface{}) *erro
|
||||
// AsConflictError checks whether the err is duplicate key error. If it it, wrap it
|
||||
// as a src/internal/error.Error with conflict error code, else return nil
|
||||
func AsConflictError(err error, messageFormat string, args ...interface{}) *errors.Error {
|
||||
var pqErr *pq.Error
|
||||
if errors.As(err, &pqErr) && pqErr.Code == "23505" {
|
||||
if isDuplicateKeyError(err) {
|
||||
e := errors.New(err).
|
||||
WithCode(errors.ConflictCode).
|
||||
WithMessage(messageFormat, args...)
|
||||
@ -67,8 +66,7 @@ func AsConflictError(err error, messageFormat string, args ...interface{}) *erro
|
||||
// AsForeignKeyError checks whether the err is violating foreign key constraint error. If it it, wrap it
|
||||
// as a src/internal/error.Error with violating foreign key constraint error code, else return nil
|
||||
func AsForeignKeyError(err error, messageFormat string, args ...interface{}) *errors.Error {
|
||||
var pqErr *pq.Error
|
||||
if errors.As(err, &pqErr) && pqErr.Code == "23503" {
|
||||
if isViolatingForeignKeyConstraintError(err) {
|
||||
e := errors.New(err).
|
||||
WithCode(errors.ViolateForeignKeyConstraintCode).
|
||||
WithMessage(messageFormat, args...)
|
||||
@ -76,3 +74,21 @@ func AsForeignKeyError(err error, messageFormat string, args ...interface{}) *er
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func isDuplicateKeyError(err error) bool {
|
||||
var pqErr *pq.Error
|
||||
if errors.As(err, &pqErr) && pqErr.Code == "23505" {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
func isViolatingForeignKeyConstraintError(err error) bool {
|
||||
var pqErr *pq.Error
|
||||
if errors.As(err, &pqErr) && pqErr.Code == "23503" {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
@ -17,6 +17,7 @@ package orm
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/astaxie/beego/orm"
|
||||
"github.com/goharbor/harbor/src/lib/log"
|
||||
@ -87,3 +88,61 @@ func WithTransaction(f func(ctx context.Context) error) func(ctx context.Context
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// ReadOrCreate read or create instance to datebase, retry to read when met a duplicate key error after the creating
|
||||
func ReadOrCreate(ctx context.Context, md interface{}, col1 string, cols ...string) (created bool, id int64, err error) {
|
||||
getter, ok := md.(interface {
|
||||
GetID() int64
|
||||
})
|
||||
|
||||
if !ok {
|
||||
err = fmt.Errorf("missing GetID method for the model %T", md)
|
||||
return
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if !created && err == nil { // found in the database
|
||||
id = getter.GetID()
|
||||
}
|
||||
}()
|
||||
|
||||
o, err := FromContext(ctx)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
cols = append([]string{col1}, cols...)
|
||||
|
||||
err = o.Read(md, cols...)
|
||||
if err == nil { // found in the database
|
||||
return
|
||||
}
|
||||
|
||||
if !errors.Is(err, orm.ErrNoRows) { // met a error when read database
|
||||
return
|
||||
}
|
||||
|
||||
// not found in the database, try to create one
|
||||
err = WithTransaction(func(ctx context.Context) error {
|
||||
o, err := FromContext(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
id, err = o.Insert(md)
|
||||
return err
|
||||
})(ctx)
|
||||
|
||||
if err == nil { // create success
|
||||
created = true
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// got a duplicate key error, try to read again
|
||||
if isDuplicateKeyError(err) {
|
||||
err = o.Read(md, cols...)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
@ -17,6 +17,7 @@ package orm
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/astaxie/beego/orm"
|
||||
@ -29,10 +30,14 @@ type Foo struct {
|
||||
Name string `orm:"column(name)"`
|
||||
}
|
||||
|
||||
func (*Foo) TableName() string {
|
||||
func (foo *Foo) TableName() string {
|
||||
return "foo"
|
||||
}
|
||||
|
||||
func (foo *Foo) GetID() int64 {
|
||||
return foo.ID
|
||||
}
|
||||
|
||||
func addFoo(ctx context.Context, foo Foo) (int64, error) {
|
||||
o, err := FromContext(ctx)
|
||||
if err != nil {
|
||||
@ -349,6 +354,61 @@ func (suite *OrmSuite) TestNestedSavepoint() {
|
||||
suite.False(existFoo(ctx, id2))
|
||||
}
|
||||
|
||||
func (suite *OrmSuite) TestReadOrCreate() {
|
||||
ctx := NewContext(context.TODO(), orm.NewOrm())
|
||||
|
||||
var id int64
|
||||
f1 := func(ctx context.Context) (err error) {
|
||||
created1, id1, err := ReadOrCreate(ctx, &Foo{Name: "n1"}, "name")
|
||||
suite.NoError(err)
|
||||
suite.True(created1)
|
||||
|
||||
created2, id2, err := ReadOrCreate(ctx, &Foo{Name: "n1"}, "name")
|
||||
suite.NoError(err)
|
||||
suite.False(created2)
|
||||
|
||||
suite.Equal(id2, id1)
|
||||
|
||||
id = id1
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
suite.NoError(WithTransaction(f1)(ctx))
|
||||
suite.True(existFoo(ctx, id))
|
||||
}
|
||||
|
||||
func (suite *OrmSuite) TestReadOrCreateParallel() {
|
||||
count := 500
|
||||
|
||||
arr := make([]int, count)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for i := 0; i < count; i++ {
|
||||
wg.Add(1)
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
|
||||
ctx := NewContext(context.TODO(), orm.NewOrm())
|
||||
created, _, err := ReadOrCreate(ctx, &Foo{Name: "n2"}, "name")
|
||||
suite.NoError(err)
|
||||
|
||||
if created {
|
||||
arr[i] = 1
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
sum := 0
|
||||
for _, v := range arr {
|
||||
sum += v
|
||||
}
|
||||
|
||||
suite.Equal(1, sum)
|
||||
}
|
||||
|
||||
func TestRunOrmSuite(t *testing.T) {
|
||||
suite.Run(t, new(OrmSuite))
|
||||
}
|
||||
|
@ -70,6 +70,23 @@ type VulnerabilityRecord struct {
|
||||
VendorAttributes string `orm:"column(vendor_attributes);type(json);null"`
|
||||
}
|
||||
|
||||
// TableName for VulnerabilityRecord
|
||||
func (vr *VulnerabilityRecord) TableName() string {
|
||||
return "vulnerability_record"
|
||||
}
|
||||
|
||||
// TableUnique for VulnerabilityRecord
|
||||
func (vr *VulnerabilityRecord) TableUnique() [][]string {
|
||||
return [][]string{
|
||||
{"cve_id", "registration_uuid", "package", "package_version"},
|
||||
}
|
||||
}
|
||||
|
||||
// GetID returns the ID of the record
|
||||
func (vr *VulnerabilityRecord) GetID() int64 {
|
||||
return vr.ID
|
||||
}
|
||||
|
||||
// ReportVulnerabilityRecord is relation table required to optimize data storage for both the
|
||||
// vulnerability records and the scan report.
|
||||
// identified by composite key (ID, Report)
|
||||
@ -83,18 +100,6 @@ type ReportVulnerabilityRecord struct {
|
||||
VulnRecordID int64 `orm:"column(vuln_record_id);"`
|
||||
}
|
||||
|
||||
// TableName for VulnerabilityRecord
|
||||
func (vr *VulnerabilityRecord) TableName() string {
|
||||
return "vulnerability_record"
|
||||
}
|
||||
|
||||
// TableUnique for VulnerabilityRecord
|
||||
func (vr *VulnerabilityRecord) TableUnique() [][]string {
|
||||
return [][]string{
|
||||
{"cve_id", "registration_uuid", "package", "package_version"},
|
||||
}
|
||||
}
|
||||
|
||||
// TableName for ReportVulnerabilityRecord
|
||||
func (rvr *ReportVulnerabilityRecord) TableName() string {
|
||||
return "report_vulnerability_record"
|
||||
@ -106,3 +111,8 @@ func (rvr *ReportVulnerabilityRecord) TableUnique() [][]string {
|
||||
{"report_uuid", "vuln_record_id"},
|
||||
}
|
||||
}
|
||||
|
||||
// GetID returns the ID of the record
|
||||
func (rvr *ReportVulnerabilityRecord) GetID() int64 {
|
||||
return rvr.ID
|
||||
}
|
||||
|
@ -17,7 +17,7 @@ package scan
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"github.com/goharbor/harbor/src/lib/errors"
|
||||
|
||||
"github.com/goharbor/harbor/src/lib/orm"
|
||||
"github.com/goharbor/harbor/src/lib/q"
|
||||
)
|
||||
@ -60,19 +60,8 @@ type vulnerabilityRecordDao struct{}
|
||||
|
||||
// Create creates new vulnerability record.
|
||||
func (v *vulnerabilityRecordDao) Create(ctx context.Context, vr *VulnerabilityRecord) (int64, error) {
|
||||
o, err := orm.FromContext(ctx)
|
||||
var vrID int64
|
||||
err = orm.WithTransaction(func(ctx context.Context) error {
|
||||
var err error
|
||||
vrID, err = o.InsertOrUpdate(vr, "cve_id, registration_uuid, package, package_version")
|
||||
return orm.WrapConflictError(err, "vulnerability already exists")
|
||||
})(ctx)
|
||||
if errors.IsConflictErr(err) {
|
||||
if err := o.Read(vr, "cve_id", "registration_uuid", "package", "package_version"); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return vr.ID, nil
|
||||
}
|
||||
_, vrID, err := orm.ReadOrCreate(ctx, vr, "cve_id", "registration_uuid", "package", "package_version")
|
||||
|
||||
return vrID, err
|
||||
}
|
||||
|
||||
@ -137,11 +126,7 @@ func (v *vulnerabilityRecordDao) InsertForReport(ctx context.Context, reportUUID
|
||||
rvr.Report = reportUUID
|
||||
rvr.VulnRecordID = vrID
|
||||
|
||||
o, err := orm.FromContext(ctx)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
_, rvrID, err := o.ReadOrCreate(rvr, "report_uuid", "vuln_record_id")
|
||||
_, rvrID, err := orm.ReadOrCreate(ctx, rvr, "report_uuid", "vuln_record_id")
|
||||
|
||||
return rvrID, err
|
||||
|
||||
@ -164,8 +149,8 @@ func (v *vulnerabilityRecordDao) GetForReport(ctx context.Context, reportUUID st
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
query := `select vulnerability_record.* from vulnerability_record
|
||||
inner join report_vulnerability_record on
|
||||
query := `select vulnerability_record.* from vulnerability_record
|
||||
inner join report_vulnerability_record on
|
||||
vulnerability_record.id = report_vulnerability_record.vuln_record_id and report_vulnerability_record.report_uuid=?`
|
||||
_, err = o.Raw(query, reportUUID).QueryRows(&vulnRecs)
|
||||
return vulnRecs, err
|
||||
|
@ -1,17 +1,32 @@
|
||||
// Copyright Project Harbor Authors
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
package scan
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/goharbor/harbor/src/jobservice/job"
|
||||
"github.com/goharbor/harbor/src/lib/orm"
|
||||
"github.com/goharbor/harbor/src/lib/q"
|
||||
"github.com/goharbor/harbor/src/pkg/scan/dao/scanner"
|
||||
"github.com/goharbor/harbor/src/pkg/scan/rest/v1"
|
||||
v1 "github.com/goharbor/harbor/src/pkg/scan/rest/v1"
|
||||
htesting "github.com/goharbor/harbor/src/testing"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/stretchr/testify/suite"
|
||||
"testing"
|
||||
)
|
||||
|
||||
const sampleReportWithCompleteVulnData = `{
|
||||
|
Loading…
Reference in New Issue
Block a user