From 130452111b628a85a66a59985fb561f0bf268511 Mon Sep 17 00:00:00 2001 From: prahaladdarkin Date: Mon, 11 Jul 2022 14:05:04 +0530 Subject: [PATCH] Vulnerability scan data export functionality (#15998) Vulnerability Scan Data (CVE) Export Functionality Proposal - goharbor/community#174 Closes - https://github.com/goharbor/harbor/issues/17150 Changes: * CVE Data export to CSV with filtering support. * Implement CSV data export job for creating CSVs * APIs to trigger CSV export job executions Signed-off-by: prahaladdarkin --- api/v2.0/swagger.yaml | 214 ++++++ make/photon/jobservice/Dockerfile | 2 +- .../docker_compose/docker-compose.yml.jinja | 1 + src/controller/scandataexport/execution.go | 194 ++++++ .../scandataexport/execution_test.go | 332 +++++++++ src/go.mod | 1 + src/go.sum | 3 + .../impl/scandataexport/scan_data_export.go | 341 +++++++++ .../scandataexport/scan_data_export_test.go | 649 ++++++++++++++++++ src/jobservice/job/known_jobs.go | 2 + src/jobservice/runtime/bootstrap.go | 2 + src/pkg/scan/export/constants.go | 15 + src/pkg/scan/export/digest_calculator.go | 28 + src/pkg/scan/export/digest_calculator_test.go | 32 + src/pkg/scan/export/export_data_selector.go | 60 ++ .../scan/export/export_data_selector_test.go | 123 ++++ src/pkg/scan/export/filter_processor.go | 142 ++++ src/pkg/scan/export/filter_processor_test.go | 233 +++++++ src/pkg/scan/export/manager.go | 233 +++++++ src/pkg/scan/export/manager_test.go | 371 ++++++++++ src/pkg/scan/export/model.go | 115 ++++ src/pkg/systemartifact/manager.go | 11 +- src/pkg/systemartifact/manager_test.go | 24 +- src/server/v2.0/handler/handler.go | 1 + src/server/v2.0/handler/scanexport.go | 237 +++++++ src/server/v2.0/handler/scanexport_test.go | 396 +++++++++++ src/testing/controller/controller.go | 1 + .../controller/scandataexport/controller.go | 136 ++++ src/testing/pkg/pkg.go | 4 + src/testing/pkg/registry/client.go | 2 +- .../pkg/registry/fake_registry_client.go | 325 +++++++++ .../scan/dao/scan/vulnerability_record_dao.go | 256 +++++++ .../scan/export/artifact_digest_calculator.go | 50 ++ .../pkg/scan/export/filter_processor.go | 100 +++ src/testing/pkg/scan/export/manager.go | 53 ++ .../github.com/gocarina/gocsv/.gitignore | 1 + .../github.com/gocarina/gocsv/.travis.yml | 4 + src/vendor/github.com/gocarina/gocsv/LICENSE | 21 + .../github.com/gocarina/gocsv/README.md | 168 +++++ src/vendor/github.com/gocarina/gocsv/csv.go | 517 ++++++++++++++ .../github.com/gocarina/gocsv/decode.go | 463 +++++++++++++ .../github.com/gocarina/gocsv/encode.go | 147 ++++ src/vendor/github.com/gocarina/gocsv/go.mod | 3 + .../github.com/gocarina/gocsv/reflect.go | 143 ++++ .../github.com/gocarina/gocsv/safe_csv.go | 38 + src/vendor/github.com/gocarina/gocsv/types.go | 456 ++++++++++++ .../github.com/gocarina/gocsv/unmarshaller.go | 117 ++++ src/vendor/modules.txt | 3 + 48 files changed, 6762 insertions(+), 8 deletions(-) create mode 100644 src/controller/scandataexport/execution.go create mode 100644 src/controller/scandataexport/execution_test.go create mode 100644 src/jobservice/job/impl/scandataexport/scan_data_export.go create mode 100644 src/jobservice/job/impl/scandataexport/scan_data_export_test.go create mode 100644 src/pkg/scan/export/constants.go create mode 100644 src/pkg/scan/export/digest_calculator.go create mode 100644 src/pkg/scan/export/digest_calculator_test.go create mode 100644 src/pkg/scan/export/export_data_selector.go create mode 100644 src/pkg/scan/export/export_data_selector_test.go create mode 100644 
src/pkg/scan/export/filter_processor.go create mode 100644 src/pkg/scan/export/filter_processor_test.go create mode 100644 src/pkg/scan/export/manager.go create mode 100644 src/pkg/scan/export/manager_test.go create mode 100644 src/pkg/scan/export/model.go create mode 100644 src/server/v2.0/handler/scanexport.go create mode 100644 src/server/v2.0/handler/scanexport_test.go create mode 100644 src/testing/controller/scandataexport/controller.go create mode 100644 src/testing/pkg/registry/fake_registry_client.go create mode 100644 src/testing/pkg/scan/dao/scan/vulnerability_record_dao.go create mode 100644 src/testing/pkg/scan/export/artifact_digest_calculator.go create mode 100644 src/testing/pkg/scan/export/filter_processor.go create mode 100644 src/testing/pkg/scan/export/manager.go create mode 100644 src/vendor/github.com/gocarina/gocsv/.gitignore create mode 100644 src/vendor/github.com/gocarina/gocsv/.travis.yml create mode 100644 src/vendor/github.com/gocarina/gocsv/LICENSE create mode 100644 src/vendor/github.com/gocarina/gocsv/README.md create mode 100644 src/vendor/github.com/gocarina/gocsv/csv.go create mode 100644 src/vendor/github.com/gocarina/gocsv/decode.go create mode 100644 src/vendor/github.com/gocarina/gocsv/encode.go create mode 100644 src/vendor/github.com/gocarina/gocsv/go.mod create mode 100644 src/vendor/github.com/gocarina/gocsv/reflect.go create mode 100644 src/vendor/github.com/gocarina/gocsv/safe_csv.go create mode 100644 src/vendor/github.com/gocarina/gocsv/types.go create mode 100644 src/vendor/github.com/gocarina/gocsv/unmarshaller.go diff --git a/api/v2.0/swagger.yaml b/api/v2.0/swagger.yaml index ca5ee6094..90929ed72 100644 --- a/api/v2.0/swagger.yaml +++ b/api/v2.0/swagger.yaml @@ -5595,6 +5595,122 @@ paths: '500': $ref: '#/responses/500' + /export/cve: + post: + summary: Export scan data for selected projects + description: Export scan data for selected projects + tags: + - scan data export + operationId: exportScanData + parameters: + - $ref: '#/parameters/requestId' + - $ref: '#/parameters/scanDataType' + - name: criteria + in: body + description: The criteria for the export + required: true + schema: + $ref: '#/definitions/ScanDataExportRequest' + responses: + '200': + description: Success. 
+          schema:
+            $ref: '#/definitions/ScanDataExportJob'
+        '400':
+          $ref: '#/responses/400'
+        '401':
+          $ref: '#/responses/401'
+        '403':
+          $ref: '#/responses/403'
+        '404':
+          $ref: '#/responses/404'
+        '405':
+          $ref: '#/responses/405'
+        '409':
+          $ref: '#/responses/409'
+        '500':
+          $ref: '#/responses/500'
+  /export/cve/execution/{execution_id}:
+    get:
+      summary: Get the specific scan data export execution
+      description: Get the scan data export execution specified by ID
+      tags:
+        - scan data export
+      operationId: getScanDataExportExecution
+      parameters:
+        - $ref: '#/parameters/requestId'
+        - $ref: '#/parameters/executionId'
+      responses:
+        '200':
+          description: Success
+          schema:
+            $ref: '#/definitions/ScanDataExportExecution'
+        '401':
+          $ref: '#/responses/401'
+        '403':
+          $ref: '#/responses/403'
+        '404':
+          $ref: '#/responses/404'
+        '500':
+          $ref: '#/responses/500'
+  /export/cve/executions:
+    get:
+      summary: Get a list of scan data export executions for a specified user
+      description: Get the list of scan data export executions triggered by the specified user
+      tags:
+        - scan data export
+      operationId: getScanDataExportExecutionList
+      parameters:
+        - $ref: '#/parameters/requestId'
+        - $ref: '#/parameters/userName'
+      responses:
+        '200':
+          description: Success
+          schema:
+            $ref: '#/definitions/ScanDataExportExecutionList'
+        '401':
+          $ref: '#/responses/401'
+        '403':
+          $ref: '#/responses/403'
+        '404':
+          $ref: '#/responses/404'
+        '500':
+          $ref: '#/responses/500'
+  /export/cve/download/{execution_id}:
+    get:
+      summary: Download the scan data export file
+      description: Download the scan data report. The default format is CSV
+      tags:
+        - scan data export
+      operationId: downloadScanData
+      produces:
+        - text/csv
+      parameters:
+        - $ref: '#/parameters/requestId'
+        - $ref: '#/parameters/executionId'
+        - name: format
+          in: query
+          type: string
+          required: false
+          description: The format of the data to be exported, e.g. CSV or PDF
+      responses:
+        '200':
+          description: Data file containing the export data
+          schema:
+            type: file
+          headers:
+            Content-Disposition:
+              type: string
+              description: Value is a CSV formatted file; filename=export.csv
+        '401':
+          $ref: '#/responses/401'
+        '403':
+          $ref: '#/responses/403'
+        '404':
+          $ref: '#/responses/404'
+        '500':
+          $ref: '#/responses/500'
+
 parameters:
   query:
     name: q
@@ -5762,6 +5878,18 @@ parameters:
     required: true
     type: integer
     format: int64
+  scanDataType:
+    name: X-Scan-Data-Type
+    description: The type of scan data to export
+    in: header
+    type: string
+    required: true
+  userName:
+    name: user_name
+    description: The name of the user
+    in: query
+    type: string
+    required: true
 
 responses:
   '200':
@@ -9036,3 +9164,89 @@ definitions:
       type: string
       format: date-time
       description: The creation time of the accessory
+
+  ScanDataExportRequest:
+    type: object
+    description: The criteria to select the scan data to export.
+    properties:
+      job_name:
+        type: string
+        description: Name of the scan data export job
+      projects:
+        type: array
+        items:
+          type: integer
+          format: int64
+        description: A list of one or more projects for which to export the scan data, defaults to all if empty
+      labels:
+        type: array
+        items:
+          type: integer
+          format: int64
+        description: A list of one or more labels for which to export the scan data, defaults to all if empty
+      repositories:
+        type: string
+        description: A list of repositories for which to export the scan data, defaults to all if empty
+      cveIds:
+        type: string
+        description: CVE-IDs for which to export data. Multiple CVE-IDs can be specified by separating them with ',' and enclosing the list in '{}'. Defaults to all if empty
+      tags:
+        type: string
+        description: A list of tags enclosed within '{}'. Defaults to all if empty
+  ScanDataExportJob:
+    type: object
+    description: The metadata associated with the scan data export job
+    properties:
+      id:
+        type: integer
+        format: int64
+        description: The id of the scan data export job
+  ScanDataExportExecution:
+    type: object
+    description: The scan data export execution
+    properties:
+      id:
+        type: integer
+        description: The ID of the execution
+      user_id:
+        type: integer
+        description: The ID of the user triggering the export job
+      status:
+        type: string
+        description: The status of the execution
+      trigger:
+        type: string
+        description: The trigger mode
+      start_time:
+        type: string
+        format: date-time
+        description: The start time
+      end_time:
+        type: string
+        format: date-time
+        description: The end time
+      status_text:
+        type: string
+        x-omitempty: false
+        description: The status text
+      job_name:
+        type: string
+        x-omitempty: false
+        description: The name of the job as specified by the user
+      user_name:
+        type: string
+        x-omitempty: false
+        description: The name of the user triggering the job
+      file_present:
+        type: boolean
+        x-omitempty: false
+        description: Indicates whether the export artifact is present in the registry
+  ScanDataExportExecutionList:
+    type: object
+    description: The list of scan data export executions
+    properties:
+      items:
+        type: array
+        items:
+          $ref: '#/definitions/ScanDataExportExecution'
+        description: The list of scan data export executions
\ No newline at end of file
diff --git a/make/photon/jobservice/Dockerfile b/make/photon/jobservice/Dockerfile
index accd7519d..e738a55c7 100644
--- a/make/photon/jobservice/Dockerfile
+++ b/make/photon/jobservice/Dockerfile
@@ -17,7 +17,7 @@ WORKDIR /harbor/
 
 USER harbor
 
-VOLUME ["/var/log/jobs/"]
+VOLUME ["/var/log/jobs/", "/var/scandata_exports"]
 
 HEALTHCHECK CMD curl --fail -s http://localhost:8080/api/v1/stats || curl -sk --fail --key /etc/harbor/ssl/job_service.key --cert /etc/harbor/ssl/job_service.crt https://localhost:8443/api/v1/stats || exit 1
diff --git a/make/photon/prepare/templates/docker_compose/docker-compose.yml.jinja b/make/photon/prepare/templates/docker_compose/docker-compose.yml.jinja
index 26d76a3a0..63f3205af 100644
--- a/make/photon/prepare/templates/docker_compose/docker-compose.yml.jinja
+++ b/make/photon/prepare/templates/docker_compose/docker-compose.yml.jinja
@@ -253,6 +253,7 @@ services:
       - SETUID
     volumes:
       - {{data_volume}}/job_logs:/var/log/jobs:z
+      - {{data_volume}}/scandata_exports:/var/scandata_exports:z
       - type: bind
         source: ./common/config/jobservice/config.yml
         target: /etc/jobservice/config.yml
diff --git a/src/controller/scandataexport/execution.go b/src/controller/scandataexport/execution.go
new file mode 100644
index 000000000..fe3796162
--- /dev/null
+++ b/src/controller/scandataexport/execution.go
@@ -0,0 +1,194 @@
+package scandataexport
+
+import (
+	"context"
+	"fmt"
+	"strings"
+	"time"
+
+	"github.com/goharbor/harbor/src/jobservice/job"
+	"github.com/goharbor/harbor/src/jobservice/logger"
+	"github.com/goharbor/harbor/src/lib/errors"
+	"github.com/goharbor/harbor/src/lib/log"
+	"github.com/goharbor/harbor/src/lib/orm"
+	q2 "github.com/goharbor/harbor/src/lib/q"
+	"github.com/goharbor/harbor/src/pkg/scan/export"
+	"github.com/goharbor/harbor/src/pkg/systemartifact"
+	"github.com/goharbor/harbor/src/pkg/task"
+)
+
+func init() {
+	task.SetExecutionSweeperCount(job.ScanDataExport, 50)
+}
+
+// Ctl is the default controller instance for scan data export operations.
+var Ctl = NewController()
+
+// Controller defines the operations for managing scan data export executions.
+type Controller interface {
+	Start(ctx context.Context, criteria export.Request) (executionID int64, err error)
+	GetExecution(ctx context.Context, executionID int64) (*export.Execution, error)
+	ListExecutions(ctx context.Context, userName string) ([]*export.Execution, error)
+	GetTask(ctx context.Context, executionID int64) (*task.Task, error)
+	DeleteExecution(ctx context.Context, executionID int64) error
+}
+
+// NewController returns a new Controller instance backed by the default managers.
+func NewController() Controller {
+	return &controller{
+		execMgr:        task.ExecMgr,
+		taskMgr:        task.Mgr,
+		makeCtx:        orm.Context,
+		sysArtifactMgr: systemartifact.Mgr,
+	}
+}
+
+type controller struct {
+	execMgr        task.ExecutionManager
+	taskMgr        task.Manager
+	makeCtx        func() context.Context
+	sysArtifactMgr systemartifact.Manager
+}
+
+func (c *controller) ListExecutions(ctx context.Context, userName string) ([]*export.Execution, error) {
+	keywords := make(map[string]interface{})
+	keywords["VendorType"] = job.ScanDataExport
+	keywords[fmt.Sprintf("ExtraAttrs.%s", export.UserNameAttribute)] = userName
+
+	q := q2.New(q2.KeyWords{})
+	q.Keywords = keywords
+	execsForUser, err := c.execMgr.List(ctx, q)
+	if err != nil {
+		return nil, err
+	}
+	execs := make([]*export.Execution, 0)
+	for _, execForUser := range execsForUser {
+		execs = append(execs, c.convertToExportExecStatus(ctx, execForUser))
+	}
+	return execs, nil
+}
+
+func (c *controller) GetTask(ctx context.Context, executionID int64) (*task.Task, error) {
+	query := q2.New(q2.KeyWords{})
+
+	keywords := make(map[string]interface{})
+	keywords["VendorType"] = job.ScanDataExport
+	keywords["ExecutionID"] = executionID
+	query.Keywords = keywords
+	query.Sorts = append(query.Sorts, &q2.Sort{
+		Key:  "ID",
+		DESC: true,
+	})
+	tasks, err := c.taskMgr.List(ctx, query)
+	if err != nil {
+		return nil, err
+	}
+	if len(tasks) == 0 {
+		return nil, errors.Errorf("No task found for execution Id : %d", executionID)
+	}
+	// for the export job there is a single task instance corresponding to the execution,
+	// so return the latest task instance associated with this execution
+	logger.Infof("Returning task instance with ID : %d", tasks[0].ID)
+	return tasks[0], nil
+}
+
+func (c *controller) GetExecution(ctx context.Context, executionID int64) (*export.Execution, error) {
+	exec, err := c.execMgr.Get(ctx, executionID)
+	if err != nil {
+		logger.Errorf("Error when fetching execution status for ExecutionId: %d error : %v", executionID, err)
+		return nil, err
+	}
+	if exec == nil {
+		logger.Infof("No execution found for ExecutionId: %d", executionID)
+		return nil, nil
+	}
+	return c.convertToExportExecStatus(ctx, exec), nil
+}
+
+func (c *controller) DeleteExecution(ctx context.Context, executionID int64) error {
+	err := c.execMgr.Delete(ctx, executionID)
+	if err != nil {
+		logger.Errorf("Error when deleting execution for ExecutionId: %d, error : %v", executionID, err)
+	}
+	return err
+}
+
+func (c *controller) Start(ctx context.Context, request export.Request) (executionID int64, err error) {
+	logger := log.GetLogger(ctx)
+	vendorID := int64(ctx.Value(export.CsvJobVendorIDKey).(int))
+	extraAttrs := make(map[string]interface{})
+	extraAttrs[export.JobNameAttribute] = request.JobName
+	extraAttrs[export.UserNameAttribute] = request.UserName
+	id, err := c.execMgr.Create(ctx, job.ScanDataExport, vendorID, task.ExecutionTriggerManual, extraAttrs)
+	logger.Infof("Created an execution record with id : %d for vendorID: %d", id, 
vendorID) + if err != nil { + logger.Errorf("Encountered error when creating job : %v", err) + return 0, err + } + + // create a job object and fill with metadata and parameters + params := make(map[string]interface{}) + params["JobId"] = id + params["Request"] = request + params[export.JobModeKey] = export.JobModeExport + + j := &task.Job{ + Name: job.ScanDataExport, + Metadata: &job.Metadata{ + JobKind: job.KindGeneric, + }, + Parameters: params, + } + + _, err = c.taskMgr.Create(ctx, id, j) + + if err != nil { + logger.Errorf("Unable to create a scan data export job: %v", err) + c.markError(ctx, id, err) + return 0, err + } + + logger.Info("Created job for scan data export successfully") + return id, nil +} + +func (c *controller) markError(ctx context.Context, executionID int64, err error) { + logger := log.GetLogger(ctx) + // try to stop the execution first in case that some tasks are already created + if err := c.execMgr.StopAndWait(ctx, executionID, 10*time.Second); err != nil { + logger.Errorf("failed to stop the execution %d: %v", executionID, err) + } + if err := c.execMgr.MarkError(ctx, executionID, err.Error()); err != nil { + logger.Errorf("failed to mark error for the execution %d: %v", executionID, err) + } +} + +func (c *controller) convertToExportExecStatus(ctx context.Context, exec *task.Execution) *export.Execution { + execStatus := &export.Execution{ + ID: exec.ID, + UserID: exec.VendorID, + Status: exec.Status, + StatusMessage: exec.StatusMessage, + Trigger: exec.Trigger, + StartTime: exec.StartTime, + EndTime: exec.EndTime, + } + if digest, ok := exec.ExtraAttrs[export.DigestKey]; ok { + execStatus.ExportDataDigest = digest.(string) + } + if jobName, ok := exec.ExtraAttrs[export.JobNameAttribute]; ok { + execStatus.JobName = jobName.(string) + } + if userName, ok := exec.ExtraAttrs[export.UserNameAttribute]; ok { + execStatus.UserName = userName.(string) + } + artifactExists := c.isCsvArtifactPresent(ctx, exec.ID, execStatus.ExportDataDigest) + execStatus.FilePresent = artifactExists + return execStatus +} + +func (c *controller) isCsvArtifactPresent(ctx context.Context, execID int64, digest string) bool { + repositoryName := fmt.Sprintf("scandata_export_%v", execID) + exists, err := c.sysArtifactMgr.Exists(ctx, strings.ToLower(export.Vendor), repositoryName, digest) + if err != nil { + exists = false + } + return exists +} diff --git a/src/controller/scandataexport/execution_test.go b/src/controller/scandataexport/execution_test.go new file mode 100644 index 000000000..75662a588 --- /dev/null +++ b/src/controller/scandataexport/execution_test.go @@ -0,0 +1,332 @@ +package scandataexport + +import ( + "context" + "github.com/goharbor/harbor/src/jobservice/job" + "github.com/goharbor/harbor/src/lib/orm" + "github.com/goharbor/harbor/src/pkg/scan/export" + "github.com/goharbor/harbor/src/pkg/task" + ormtesting "github.com/goharbor/harbor/src/testing/lib/orm" + "github.com/goharbor/harbor/src/testing/mock" + systemartifacttesting "github.com/goharbor/harbor/src/testing/pkg/systemartifact" + testingTask "github.com/goharbor/harbor/src/testing/pkg/task" + "github.com/pkg/errors" + testifymock "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/suite" + "testing" + "time" +) + +type ScanDataExportExecutionTestSuite struct { + suite.Suite + execMgr *testingTask.ExecutionManager + taskMgr *testingTask.Manager + sysArtifactMgr *systemartifacttesting.Manager + ctl *controller +} + +func (suite *ScanDataExportExecutionTestSuite) SetupSuite() { +} + +func (suite 
*ScanDataExportExecutionTestSuite) TestGetTask() { + suite.taskMgr = &testingTask.Manager{} + suite.execMgr = &testingTask.ExecutionManager{} + suite.sysArtifactMgr = &systemartifacttesting.Manager{} + suite.ctl = &controller{ + execMgr: suite.execMgr, + taskMgr: suite.taskMgr, + makeCtx: func() context.Context { return orm.NewContext(nil, &ormtesting.FakeOrmer{}) }, + sysArtifactMgr: suite.sysArtifactMgr, + } + // valid task execution record exists for an execution id + { + t := task.Task{ + ID: 1, + VendorType: "SCAN_DATA_EXPORT", + ExecutionID: 100, + Status: "Success", + StatusMessage: "", + RunCount: 1, + JobID: "TestJobId", + ExtraAttrs: nil, + CreationTime: time.Time{}, + StartTime: time.Time{}, + UpdateTime: time.Time{}, + EndTime: time.Time{}, + StatusRevision: 0, + } + + tasks := make([]*task.Task, 0) + tasks = append(tasks, &t) + mock.OnAnything(suite.taskMgr, "List").Return(tasks, nil).Once() + returnedTask, err := suite.ctl.GetTask(context.Background(), 100) + suite.NoError(err) + suite.Equal(t, *returnedTask) + } + + // no task records exist for an execution id + { + tasks := make([]*task.Task, 0) + mock.OnAnything(suite.taskMgr, "List").Return(tasks, nil).Once() + _, err := suite.ctl.GetTask(context.Background(), 100) + suite.Error(err) + } + + // listing of tasks returns an error + { + mock.OnAnything(suite.taskMgr, "List").Return(nil, errors.New("test error")).Once() + _, err := suite.ctl.GetTask(context.Background(), 100) + suite.Error(err) + } + +} + +func (suite *ScanDataExportExecutionTestSuite) TestGetExecution() { + suite.taskMgr = &testingTask.Manager{} + suite.execMgr = &testingTask.ExecutionManager{} + suite.sysArtifactMgr = &systemartifacttesting.Manager{} + suite.ctl = &controller{ + execMgr: suite.execMgr, + taskMgr: suite.taskMgr, + makeCtx: func() context.Context { return orm.NewContext(nil, &ormtesting.FakeOrmer{}) }, + sysArtifactMgr: suite.sysArtifactMgr, + } + // get execution succeeds + attrs := make(map[string]interface{}) + attrs[export.JobNameAttribute] = "test-job" + attrs[export.UserNameAttribute] = "test-user" + { + exec := task.Execution{ + ID: 100, + VendorType: "SCAN_DATA_EXPORT", + VendorID: -1, + Status: "Success", + StatusMessage: "", + Metrics: nil, + Trigger: "Manual", + ExtraAttrs: attrs, + StartTime: time.Time{}, + UpdateTime: time.Time{}, + EndTime: time.Time{}, + } + mock.OnAnything(suite.execMgr, "Get").Return(&exec, nil).Once() + mock.OnAnything(suite.sysArtifactMgr, "Exists").Return(true, nil).Once() + + exportExec, err := suite.ctl.GetExecution(context.TODO(), 100) + suite.NoError(err) + suite.Equal(exec.ID, exportExec.ID) + suite.Equal("test-user", exportExec.UserName) + suite.Equal("test-job", exportExec.JobName) + suite.Equal(true, exportExec.FilePresent) + } + + // get execution fails + { + mock.OnAnything(suite.execMgr, "Get").Return(nil, errors.New("test error")).Once() + exportExec, err := suite.ctl.GetExecution(context.TODO(), 100) + suite.Error(err) + suite.Nil(exportExec) + } + + // get execution returns null + { + mock.OnAnything(suite.execMgr, "Get").Return(nil, nil).Once() + exportExec, err := suite.ctl.GetExecution(context.TODO(), 100) + suite.NoError(err) + suite.Nil(exportExec) + } + +} + +func (suite *ScanDataExportExecutionTestSuite) TestGetExecutionSysArtifactExistFail() { + suite.taskMgr = &testingTask.Manager{} + suite.execMgr = &testingTask.ExecutionManager{} + suite.sysArtifactMgr = &systemartifacttesting.Manager{} + suite.ctl = &controller{ + execMgr: suite.execMgr, + taskMgr: suite.taskMgr, + makeCtx: func() 
context.Context { return orm.NewContext(nil, &ormtesting.FakeOrmer{}) }, + sysArtifactMgr: suite.sysArtifactMgr, + } + // get execution succeeds + attrs := make(map[string]interface{}) + attrs[export.JobNameAttribute] = "test-job" + attrs[export.UserNameAttribute] = "test-user" + { + exec := task.Execution{ + ID: 100, + VendorType: "SCAN_DATA_EXPORT", + VendorID: -1, + Status: "Success", + StatusMessage: "", + Metrics: nil, + Trigger: "Manual", + ExtraAttrs: attrs, + StartTime: time.Time{}, + UpdateTime: time.Time{}, + EndTime: time.Time{}, + } + mock.OnAnything(suite.execMgr, "Get").Return(&exec, nil).Once() + mock.OnAnything(suite.sysArtifactMgr, "Exists").Return(false, errors.New("test error")).Once() + + exportExec, err := suite.ctl.GetExecution(context.TODO(), 100) + suite.NoError(err) + suite.Equal(exec.ID, exportExec.ID) + suite.Equal("test-user", exportExec.UserName) + suite.Equal("test-job", exportExec.JobName) + suite.Equal(false, exportExec.FilePresent) + } +} + +func (suite *ScanDataExportExecutionTestSuite) TestGetExecutionList() { + suite.taskMgr = &testingTask.Manager{} + suite.execMgr = &testingTask.ExecutionManager{} + suite.sysArtifactMgr = &systemartifacttesting.Manager{} + suite.ctl = &controller{ + execMgr: suite.execMgr, + taskMgr: suite.taskMgr, + makeCtx: func() context.Context { return orm.NewContext(nil, &ormtesting.FakeOrmer{}) }, + sysArtifactMgr: suite.sysArtifactMgr, + } + // get execution succeeds + attrs := make(map[string]interface{}) + attrs[export.JobNameAttribute] = "test-job" + attrs[export.UserNameAttribute] = "test-user" + { + exec := task.Execution{ + ID: 100, + VendorType: "SCAN_DATA_EXPORT", + VendorID: -1, + Status: "Success", + StatusMessage: "", + Metrics: nil, + Trigger: "Manual", + ExtraAttrs: attrs, + StartTime: time.Time{}, + UpdateTime: time.Time{}, + EndTime: time.Time{}, + } + execs := make([]*task.Execution, 0) + execs = append(execs, &exec) + mock.OnAnything(suite.execMgr, "List").Return(execs, nil).Once() + mock.OnAnything(suite.sysArtifactMgr, "Exists").Return(true, nil).Once() + exportExec, err := suite.ctl.ListExecutions(context.TODO(), "test-user") + suite.NoError(err) + + suite.Equal(1, len(exportExec)) + suite.Equal("test-user", exportExec[0].UserName) + suite.Equal("test-job", exportExec[0].JobName) + } + + // get execution fails + { + mock.OnAnything(suite.execMgr, "List").Return(nil, errors.New("test error")).Once() + exportExec, err := suite.ctl.ListExecutions(context.TODO(), "test-user") + suite.Error(err) + suite.Nil(exportExec) + } +} + +func (suite *ScanDataExportExecutionTestSuite) TestStart() { + suite.taskMgr = &testingTask.Manager{} + suite.execMgr = &testingTask.ExecutionManager{} + suite.ctl = &controller{ + execMgr: suite.execMgr, + taskMgr: suite.taskMgr, + makeCtx: func() context.Context { return orm.NewContext(nil, &ormtesting.FakeOrmer{}) }, + } + // execution manager and task manager return successfully + { + // get execution succeeds + attrs := make(map[string]interface{}) + attrs[export.JobNameAttribute] = "test-job" + attrs[export.UserNameAttribute] = "test-user" + suite.execMgr.On("Create", mock.Anything, mock.Anything, mock.Anything, mock.Anything, attrs).Return(int64(10), nil) + suite.taskMgr.On("Create", mock.Anything, mock.Anything, mock.Anything).Return(int64(20), nil) + ctx := context.Background() + ctx = context.WithValue(ctx, export.CsvJobVendorIDKey, int(-1)) + criteria := export.Request{} + criteria.UserName = "test-user" + criteria.JobName = "test-job" + executionId, err := suite.ctl.Start(ctx, 
criteria) + suite.NoError(err) + suite.Equal(int64(10), executionId) + suite.validateExecutionManagerInvocation(ctx) + } + +} + +func (suite *ScanDataExportExecutionTestSuite) TestDeleteExecution() { + suite.taskMgr = &testingTask.Manager{} + suite.execMgr = &testingTask.ExecutionManager{} + suite.ctl = &controller{ + execMgr: suite.execMgr, + taskMgr: suite.taskMgr, + makeCtx: func() context.Context { return orm.NewContext(nil, &ormtesting.FakeOrmer{}) }, + } + mock.OnAnything(suite.execMgr, "Delete").Return(nil).Once() + err := suite.ctl.DeleteExecution(context.TODO(), int64(1)) + suite.NoError(err) +} + +func (suite *ScanDataExportExecutionTestSuite) TestStartWithExecManagerError() { + suite.taskMgr = &testingTask.Manager{} + suite.execMgr = &testingTask.ExecutionManager{} + suite.ctl = &controller{ + execMgr: suite.execMgr, + taskMgr: suite.taskMgr, + makeCtx: func() context.Context { return orm.NewContext(nil, &ormtesting.FakeOrmer{}) }, + } + // execution manager returns an error + { + ctx := context.Background() + ctx = context.WithValue(ctx, export.CsvJobVendorIDKey, int(-1)) + mock.OnAnything(suite.execMgr, "Create").Return(int64(-1), errors.New("Test Error")) + _, err := suite.ctl.Start(ctx, export.Request{JobName: "test-job", UserName: "test-user"}) + suite.Error(err) + } +} + +func (suite *ScanDataExportExecutionTestSuite) TestStartWithTaskManagerError() { + suite.taskMgr = &testingTask.Manager{} + suite.execMgr = &testingTask.ExecutionManager{} + suite.ctl = &controller{ + execMgr: suite.execMgr, + taskMgr: suite.taskMgr, + makeCtx: func() context.Context { return orm.NewContext(nil, &ormtesting.FakeOrmer{}) }, + } + // execution manager is successful but task manager returns an error + // execution manager and task manager return successfully + { + ctx := context.Background() + ctx = context.WithValue(ctx, export.CsvJobVendorIDKey, int(-1)) + attrs := make(map[string]interface{}) + attrs[export.JobNameAttribute] = "test-job" + attrs[export.UserNameAttribute] = "test-user" + suite.execMgr.On("Create", mock.Anything, mock.Anything, mock.Anything, mock.Anything, attrs).Return(int64(10), nil) + suite.taskMgr.On("Create", mock.Anything, mock.Anything, mock.Anything).Return(int64(-1), errors.New("Test Error")) + mock.OnAnything(suite.execMgr, "StopAndWait").Return(nil) + mock.OnAnything(suite.execMgr, "MarkError").Return(nil) + _, err := suite.ctl.Start(ctx, export.Request{JobName: "test-job", UserName: "test-user"}) + suite.Error(err) + } +} + +func (suite *ScanDataExportExecutionTestSuite) TearDownSuite() { + suite.execMgr = nil + suite.taskMgr = nil +} + +func (suite *ScanDataExportExecutionTestSuite) validateExecutionManagerInvocation(ctx context.Context) { + // validate that execution manager has been called with the specified + extraAttsMatcher := testifymock.MatchedBy(func(m map[string]interface{}) bool { + jobName, jobNamePresent := m[export.JobNameAttribute] + userName, userNamePresent := m[export.UserNameAttribute] + return jobNamePresent && userNamePresent && jobName == "test-job" && userName == "test-user" + }) + suite.execMgr.AssertCalled(suite.T(), "Create", ctx, job.ScanDataExport, int64(-1), task.ExecutionTriggerManual, extraAttsMatcher) +} + +func TestScanDataExportExecutionTestSuite(t *testing.T) { + suite.Run(t, &ScanDataExportExecutionTestSuite{}) +} diff --git a/src/go.mod b/src/go.mod index d7d083705..f3e012ab6 100644 --- a/src/go.mod +++ b/src/go.mod @@ -28,6 +28,7 @@ require ( github.com/go-openapi/validate v0.19.10 github.com/go-redis/redis/v8 v8.11.4 
 	github.com/go-sql-driver/mysql v1.5.0
+	github.com/gocarina/gocsv v0.0.0-20210516172204-ca9e8a8ddea8
 	github.com/gocraft/work v0.5.1
 	github.com/golang-jwt/jwt/v4 v4.1.0
 	github.com/golang-migrate/migrate/v4 v4.15.1
diff --git a/src/go.sum b/src/go.sum
index ae1c611e0..2716f79b1 100644
--- a/src/go.sum
+++ b/src/go.sum
@@ -638,6 +638,9 @@ github.com/gobuffalo/packr/v2 v2.8.1/go.mod h1:c/PLlOuTU+p3SybaJATW3H6lX/iK7xEz5
 github.com/gobuffalo/syncx v0.0.0-20190224160051-33c29581e754/go.mod h1:HhnNqWY95UYwwW3uSASeV7vtgYkT2t16hJgV3AEPUpw=
 github.com/gobwas/glob v0.2.3/go.mod h1:d3Ez4x06l9bZtSvzIay5+Yzi0fmZzPgnTbPcKjJAkT8=
 github.com/gocql/gocql v0.0.0-20210515062232-b7ef815b4556/go.mod h1:DL0ekTmBSTdlNF25Orwt/JMzqIq3EJ4MVa/J/uK64OY=
+github.com/gocarina/gocsv v0.0.0-20210516172204-ca9e8a8ddea8 h1:hp1oqdzmv37vPLYFGjuM/RmUgUMfD9vQfMszc54l55Y=
+github.com/gocarina/gocsv v0.0.0-20210516172204-ca9e8a8ddea8/go.mod h1:5YoVOkjYAQumqlV356Hj3xeYh4BdZuLE0/nRkf2NKkI=
+github.com/gocql/gocql v0.0.0-20190301043612-f6df8288f9b4/go.mod h1:4Fw1eo5iaEhDUs8XyuhSVCVy52Jq3L+/3GJgYkwc+/0=
 github.com/gocraft/work v0.5.1 h1:3bRjMiOo6N4zcRgZWV3Y7uX7R22SF+A9bPTk4xRXr34=
 github.com/gocraft/work v0.5.1/go.mod h1:pc3n9Pb5FAESPPGfM0nL+7Q1xtgtRnF8rr/azzhQVlM=
 github.com/godbus/dbus v0.0.0-20151105175453-c7fdd8b5cd55/go.mod h1:/YcGZj5zSblfDWMMoOzV4fas9FZnQYTkDnsGvmh2Grw=
diff --git a/src/jobservice/job/impl/scandataexport/scan_data_export.go b/src/jobservice/job/impl/scandataexport/scan_data_export.go
new file mode 100644
index 000000000..313933131
--- /dev/null
+++ b/src/jobservice/job/impl/scandataexport/scan_data_export.go
@@ -0,0 +1,341 @@
+package scandataexport
+
+import (
+	"context"
+	"encoding/json"
+	"fmt"
+	"os"
+	"path/filepath"
+	"strings"
+
+	"github.com/gocarina/gocsv"
+	"github.com/goharbor/harbor/src/jobservice/job"
+	"github.com/goharbor/harbor/src/jobservice/logger"
+	"github.com/goharbor/harbor/src/lib/errors"
+	"github.com/goharbor/harbor/src/pkg/project"
+	"github.com/goharbor/harbor/src/pkg/scan/export"
+	"github.com/goharbor/harbor/src/pkg/systemartifact"
+	"github.com/goharbor/harbor/src/pkg/systemartifact/model"
+	"github.com/goharbor/harbor/src/pkg/task"
+	"github.com/opencontainers/go-digest"
+)
+
+// ScanDataExport is the struct that implements the scan data export job.
+// It implements the Job interface.
+type ScanDataExport struct {
+	execMgr               task.ExecutionManager
+	scanDataExportDirPath string
+	exportMgr             export.Manager
+	digestCalculator      export.ArtifactDigestCalculator
+	filterProcessor       export.FilterProcessor
+	vulnDataSelector      export.VulnerabilityDataSelector
+	projectMgr            project.Manager
+	sysArtifactMgr        systemartifact.Manager
+}
+
+// MaxFails of the job, i.e. the number of times the job can fail before it is marked as failed.
+func (sde *ScanDataExport) MaxFails() uint {
+	return 1
+}
+
+// MaxCurrency of the job. Unlike the WorkerPool concurrency, it controls the limit on the number of jobs of that type
+// that can be active at one time within a single redis instance.
+// The default value is 0, which means "no limit on job concurrency".
+func (sde *ScanDataExport) MaxCurrency() uint {
+	return 1
+}
+
+// ShouldRetry tells the worker whether to retry the failed job while the failure count is
+// still less than the number declared by the method 'MaxFails'.
+//
+// Returns:
+//  true for retry and false for no retry
+func (sde *ScanDataExport) ShouldRetry() bool {
+	return true
+}
+
+// Validate indicates whether the parameters of the job are valid.
+// Return:
+//  error if parameters are not valid. NOTES: If no parameters are needed, directly return nil.
+func (sde *ScanDataExport) Validate(params job.Parameters) error {
+	return nil
+}
+
+// Run the business logic here.
+// The related arguments will be injected by the workerpool.
+//
+// ctx Context                   : Job execution context.
+// params map[string]interface{} : parameters with key-pair style for the job execution.
+//
+// Returns:
+//  error if failed to run. NOTES: If the job is stopped or cancelled, a specified error should be returned
+//
+func (sde *ScanDataExport) Run(ctx job.Context, params job.Parameters) error {
+	if _, ok := params[export.JobModeKey]; !ok {
+		return errors.Errorf("no mode specified for scan data export execution")
+	}
+
+	mode := params[export.JobModeKey].(string)
+	logger.Infof("Scan data export job started in mode : %v", mode)
+	sde.init()
+	fileName := fmt.Sprintf("%s/scandata_export_%v.csv", sde.scanDataExportDirPath, params["JobId"])
+
+	// ensure that the CSV file is cleaned up after the Run completes.
+	defer sde.cleanupCsvFile(fileName, params)
+	err := sde.writeCsvFile(ctx, params, fileName)
+	if err != nil {
+		logger.Errorf("error when writing data to CSV: %v", err)
+		return err
+	}
+
+	hash, err := sde.calculateFileHash(fileName)
+	if err != nil {
+		logger.Errorf("Error when calculating checksum for generated file: %v", err)
+		return err
+	}
+	logger.Infof("Export Job Id = %v, FileName = %s, Hash = %v", params["JobId"], fileName, hash)
+
+	csvFile, err := os.OpenFile(fileName, os.O_RDONLY, os.ModePerm)
+	if err != nil {
+		logger.Errorf(
+			"Export Job Id = %v. Error when opening report file %s for upload to persistent storage: %v", params["JobId"], fileName, err)
+		return err
+	}
+	// release the file handle once the artifact has been persisted
+	defer csvFile.Close()
+	baseFileName := filepath.Base(fileName)
+	repositoryName := strings.TrimSuffix(baseFileName, filepath.Ext(baseFileName))
+	logger.Infof("Creating repository for CSV file with blob : %s", repositoryName)
+	stat, err := os.Stat(fileName)
+	if err != nil {
+		logger.Errorf("Error when fetching file size: %v", err)
+		return err
+	}
+	logger.Infof("Export Job Id = %v. CSV file size: %d", params["JobId"], stat.Size())
+	csvExportArtifactRecord := model.SystemArtifact{Repository: repositoryName, Digest: hash.String(), Size: stat.Size(), Type: "ScanData_CSV", Vendor: strings.ToLower(export.Vendor)}
+	artID, err := sde.sysArtifactMgr.Create(ctx.SystemContext(), &csvExportArtifactRecord, csvFile)
+	if err != nil {
+		logger.Errorf(
+			"Export Job Id = %v. Error when persisting report file %s to persistent storage: %v", params["JobId"], fileName, err)
+		return err
+	}
+
+	logger.Infof("Export Job Id = %v. Created system artifact: %v for report file %s in persistent storage", params["JobId"], artID, fileName)
+	err = sde.updateExecAttributes(ctx, params, hash)
+
+	if err != nil {
+		logger.Errorf("Export Job Id = %v. Error when updating execution record : %v", params["JobId"], err)
+		return err
+	}
+	logger.Info("Scan data export job completed")
+
+	return nil
+}
+
+func (sde *ScanDataExport) updateExecAttributes(ctx job.Context, params job.Parameters, hash digest.Digest) error {
+	execID := int64(params["JobId"].(float64))
+	exec, err := sde.execMgr.Get(ctx.SystemContext(), execID)
+	if err != nil {
+		logger.Errorf("Export Job Id = %v. Error when fetching execution record for update : %v", params["JobId"], err)
+		return err
+	}
+	attrsToUpdate := make(map[string]interface{})
+	for k, v := range exec.ExtraAttrs {
+		attrsToUpdate[k] = v
+	}
+	attrsToUpdate[export.DigestKey] = hash.String()
+	return sde.execMgr.UpdateExtraAttrs(ctx.SystemContext(), execID, attrsToUpdate)
+}
+
+func (sde *ScanDataExport) writeCsvFile(ctx job.Context, params job.Parameters, fileName string) error {
+	csvFile, err := os.OpenFile(fileName, os.O_WRONLY|os.O_CREATE|os.O_APPEND, os.ModePerm)
+	if err != nil {
+		logger.Errorf("Failed to create CSV export file %s. Error : %v", fileName, err)
+		return err
+	}
+	systemContext := ctx.SystemContext()
+	defer csvFile.Close()
+
+	logger.Infof("Created CSV export file %s", csvFile.Name())
+
+	var exportParams export.Params
+
+	if criteria, ok := params["Request"]; ok {
+		logger.Infof("Request for export : %v", criteria)
+		filterCriteria, err := sde.extractCriteria(params)
+		if err != nil {
+			return err
+		}
+
+		// check if any projects are specified. If not, then fetch all the projects
+		// of which the current user is a project admin.
+		projectIds, err := sde.filterProcessor.ProcessProjectFilter(systemContext, filterCriteria.UserName, filterCriteria.Projects)
+
+		if err != nil {
+			return err
+		}
+
+		if len(projectIds) == 0 {
+			return nil
+		}
+
+		// extract the repository ids if any repositories have been specified
+		repoIds, err := sde.getRepositoryIds(systemContext, filterCriteria.Repositories, projectIds)
+		if err != nil {
+			return err
+		}
+
+		if len(repoIds) == 0 {
+			logger.Infof("No repositories found with specified names: %v", filterCriteria.Repositories)
+			return nil
+		}
+
+		// filter the specified repositories further using the specified tags
+		repoIds, err = sde.getRepositoriesWithTags(systemContext, filterCriteria, repoIds)
+
+		if err != nil {
+			return err
+		}
+
+		if len(repoIds) == 0 {
+			logger.Infof("No repositories found with specified names: %v and tags: %v", filterCriteria.Repositories, filterCriteria.Tags)
+			return nil
+		}
+
+		exportParams = export.Params{
+			Projects:     filterCriteria.Projects,
+			Repositories: repoIds,
+			CVEIds:       filterCriteria.CVEIds,
+			Labels:       filterCriteria.Labels,
+		}
+	}
+
+	exportParams.PageNumber = 1
+	exportParams.PageSize = export.QueryPageSize
+
+	for {
+		data, err := sde.exportMgr.Fetch(systemContext, exportParams)
+		if err != nil {
+			logger.Error("Encountered error reading from the report table", err)
+			return err
+		}
+		if len(data) == 0 {
+			logger.Infof("No more data to fetch. Exiting...")
+			break
+		}
+		logger.Infof("Export Job Id = %v, Page Number = %d, Page Size = %d, Num Records = %d", params["JobId"], exportParams.PageNumber, exportParams.PageSize, len(data))
+
+		// for the first page write the CSV with the headers
+		if exportParams.PageNumber == 1 {
+			err = gocsv.Marshal(data, csvFile)
+		} else {
+			err = gocsv.MarshalWithoutHeaders(data, csvFile)
+		}
+		if err != nil {
+			// propagate marshalling errors instead of silently swallowing them
+			return err
+		}
+		exportParams.PageNumber = exportParams.PageNumber + 1
+	}
+	return nil
+}
+
+func (sde *ScanDataExport) getRepositoryIds(ctx context.Context, filter string, projectIds []int64) ([]int64, error) {
+	repositoryIds := make([]int64, 0)
+	candidates, err := sde.filterProcessor.ProcessRepositoryFilter(ctx, filter, projectIds)
+	if err != nil {
+		return nil, err
+	}
+	if candidates == nil {
+		return repositoryIds, nil
+	}
+	for _, cand := range candidates {
+		repositoryIds = append(repositoryIds, cand.NamespaceID)
+	}
+	return repositoryIds, nil
+}
+
+func (sde *ScanDataExport) getRepositoriesWithTags(ctx context.Context, filterCriteria *export.Request, repositoryIds []int64) ([]int64, error) {
+	if filterCriteria.Tags == "" {
+		return repositoryIds, nil
+	}
+	candidates, err := sde.filterProcessor.ProcessTagFilter(ctx, filterCriteria.Tags, repositoryIds)
+	if err != nil {
+		return nil, err
+	}
+	if candidates == nil {
+		return make([]int64, 0), nil
+	}
+	filteredCandidates := make([]int64, 0)
+	for _, cand := range candidates {
+		filteredCandidates = append(filteredCandidates, cand.NamespaceID)
+	}
+	return filteredCandidates, nil
+}
+
+func (sde *ScanDataExport) extractCriteria(params job.Parameters) (*export.Request, error) {
+	filterMap, ok := params["Request"].(map[string]interface{})
+	if !ok {
+		return nil, errors.Errorf("malformed criteria '%v'", params["Request"])
+	}
+	jsonData, err := json.Marshal(filterMap)
+	if err != nil {
+		return nil, err
+	}
+	criteria := &export.Request{}
+	err = criteria.FromJSON(string(jsonData))
+
+	if err != nil {
+		return nil, err
+	}
+	return criteria, nil
+}
+
+func (sde *ScanDataExport) calculateFileHash(fileName string) (digest.Digest, error) {
+	return sde.digestCalculator.Calculate(fileName)
+}
+
+func (sde *ScanDataExport) init() {
+	if sde.execMgr == nil {
+		sde.execMgr = task.NewExecutionManager()
+	}
+
+	if sde.scanDataExportDirPath == "" {
+		sde.scanDataExportDirPath = export.ScanDataExportDir
+	}
+
+	if sde.exportMgr == nil {
+		sde.exportMgr = export.NewManager()
+	}
+
+	if sde.digestCalculator == nil {
+		sde.digestCalculator = &export.SHA256ArtifactDigestCalculator{}
+	}
+
+	if sde.filterProcessor == nil {
+		sde.filterProcessor = export.NewFilterProcessor()
+	}
+
+	if sde.vulnDataSelector == nil {
+		sde.vulnDataSelector = export.NewVulnerabilityDataSelector()
+	}
+
+	if sde.projectMgr == nil {
+		sde.projectMgr = project.New()
+	}
+
+	if sde.sysArtifactMgr == nil {
+		sde.sysArtifactMgr = systemartifact.Mgr
+	}
+}
+
+func (sde *ScanDataExport) cleanupCsvFile(fileName string, params job.Parameters) {
+	if _, err := os.Stat(fileName); os.IsNotExist(err) {
+		logger.Infof("Export Job Id = %v, CSV Export File = %s does not exist. Nothing to do", params["JobId"], fileName)
+		return
+	}
+	err := os.Remove(fileName)
+	if err != nil {
+		logger.Errorf("Export Job Id = %v, CSV Export File = %s could not be deleted. 
Error = %v", params["JobId"], fileName, err) + return + } +} diff --git a/src/jobservice/job/impl/scandataexport/scan_data_export_test.go b/src/jobservice/job/impl/scandataexport/scan_data_export_test.go new file mode 100644 index 000000000..b44e04fea --- /dev/null +++ b/src/jobservice/job/impl/scandataexport/scan_data_export_test.go @@ -0,0 +1,649 @@ +package scandataexport + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "os" + "reflect" + "strconv" + "strings" + + "github.com/goharbor/harbor/src/jobservice/job" + "github.com/goharbor/harbor/src/lib/selector" + "github.com/goharbor/harbor/src/pkg/scan/export" + "github.com/goharbor/harbor/src/pkg/systemartifact/model" + "github.com/goharbor/harbor/src/pkg/task" + htesting "github.com/goharbor/harbor/src/testing" + mockjobservice "github.com/goharbor/harbor/src/testing/jobservice" + "github.com/goharbor/harbor/src/testing/mock" + "github.com/goharbor/harbor/src/testing/pkg/project" + export2 "github.com/goharbor/harbor/src/testing/pkg/scan/export" + systemartifacttesting "github.com/goharbor/harbor/src/testing/pkg/systemartifact" + tasktesting "github.com/goharbor/harbor/src/testing/pkg/task" + "github.com/opencontainers/go-digest" + testifymock "github.com/stretchr/testify/mock" + + "testing" + + "github.com/stretchr/testify/suite" +) + +const JobId = float64(100) +const MockDigest = "mockDigest" + +type ScanDataExportJobTestSuite struct { + htesting.Suite + execMgr *tasktesting.ExecutionManager + job *ScanDataExport + exportMgr *export2.Manager + digestCalculator *export2.ArtifactDigestCalculator + filterProcessor *export2.FilterProcessor + projectMgr *project.Manager + sysArtifactMgr *systemartifacttesting.Manager +} + +func (suite *ScanDataExportJobTestSuite) SetupSuite() { +} + +func (suite *ScanDataExportJobTestSuite) SetupTest() { + suite.execMgr = &tasktesting.ExecutionManager{} + suite.exportMgr = &export2.Manager{} + suite.digestCalculator = &export2.ArtifactDigestCalculator{} + suite.filterProcessor = &export2.FilterProcessor{} + suite.projectMgr = &project.Manager{} + suite.sysArtifactMgr = &systemartifacttesting.Manager{} + suite.job = &ScanDataExport{ + execMgr: suite.execMgr, + exportMgr: suite.exportMgr, + scanDataExportDirPath: "/tmp", + digestCalculator: suite.digestCalculator, + filterProcessor: suite.filterProcessor, + sysArtifactMgr: suite.sysArtifactMgr, + } + + suite.execMgr.On("UpdateExtraAttrs", mock.Anything, mock.Anything, mock.Anything).Return(nil) + // all BLOB related operations succeed + suite.sysArtifactMgr.On("Create", mock.Anything, mock.Anything, mock.Anything).Return(int64(1), nil) +} + +func (suite *ScanDataExportJobTestSuite) TestRun() { + + data := suite.createDataRecords(3, 1) + mock.OnAnything(suite.exportMgr, "Fetch").Return(data, nil).Once() + mock.OnAnything(suite.exportMgr, "Fetch").Return(make([]export.Data, 0), nil).Once() + mock.OnAnything(suite.digestCalculator, "Calculate").Return(digest.Digest(MockDigest), nil) + + execAttrs := make(map[string]interface{}) + execAttrs[export.JobNameAttribute] = "test-job" + execAttrs[export.UserNameAttribute] = "test-user" + mock.OnAnything(suite.execMgr, "Get").Return(&task.Execution{ID: int64(JobId), ExtraAttrs: execAttrs}, nil) + + params := job.Parameters{} + params[export.JobModeKey] = export.JobModeExport + params["JobId"] = JobId + ctx := &mockjobservice.MockJobContext{} + + err := suite.job.Run(ctx, params) + suite.NoError(err) + sysArtifactRecordMatcher := testifymock.MatchedBy(func(sa *model.SystemArtifact) bool { + return 
sa.Repository == "scandata_export_100" && sa.Vendor == strings.ToLower(export.Vendor) && sa.Digest == MockDigest + }) + suite.sysArtifactMgr.AssertCalled(suite.T(), "Create", mock.Anything, sysArtifactRecordMatcher, mock.Anything) + + m := make(map[string]interface{}) + m[export.DigestKey] = MockDigest + m[export.CreateTimestampKey] = mock.Anything + + extraAttrsMatcher := testifymock.MatchedBy(func(attrsMap map[string]interface{}) bool { + _, ok := m[export.CreateTimestampKey] + return attrsMap[export.DigestKey] == MockDigest && ok && attrsMap[export.JobNameAttribute] == "test-job" && attrsMap[export.UserNameAttribute] == "test-user" + }) + suite.execMgr.AssertCalled(suite.T(), "UpdateExtraAttrs", mock.Anything, int64(JobId), extraAttrsMatcher) + _, err = os.Stat("/tmp/scandata_export_100.csv") + suite.Truef(os.IsNotExist(err), "Expected CSV file to be deleted") + +} + +func (suite *ScanDataExportJobTestSuite) TestRunAttributeUpdateError() { + + data := suite.createDataRecords(3, 1) + mock.OnAnything(suite.exportMgr, "Fetch").Return(data, nil).Once() + mock.OnAnything(suite.exportMgr, "Fetch").Return(make([]export.Data, 0), nil).Once() + mock.OnAnything(suite.digestCalculator, "Calculate").Return(digest.Digest(MockDigest), nil) + + execAttrs := make(map[string]interface{}) + execAttrs[export.JobNameAttribute] = "test-job" + execAttrs[export.UserNameAttribute] = "test-user" + mock.OnAnything(suite.execMgr, "Get").Return(nil, errors.New("test-error")) + + params := job.Parameters{} + params[export.JobModeKey] = export.JobModeExport + params["JobId"] = JobId + ctx := &mockjobservice.MockJobContext{} + + err := suite.job.Run(ctx, params) + suite.Error(err) + sysArtifactRecordMatcher := testifymock.MatchedBy(func(sa *model.SystemArtifact) bool { + return sa.Repository == "scandata_export_100" && sa.Vendor == strings.ToLower(export.Vendor) && sa.Digest == MockDigest + }) + suite.sysArtifactMgr.AssertCalled(suite.T(), "Create", mock.Anything, sysArtifactRecordMatcher, mock.Anything) + + m := make(map[string]interface{}) + m[export.DigestKey] = MockDigest + m[export.CreateTimestampKey] = mock.Anything + + extraAttrsMatcher := testifymock.MatchedBy(func(attrsMap map[string]interface{}) bool { + _, ok := m[export.CreateTimestampKey] + return attrsMap[export.DigestKey] == MockDigest && ok && attrsMap[export.JobNameAttribute] == "test-job" && attrsMap[export.UserNameAttribute] == "test-user" + }) + suite.execMgr.AssertNotCalled(suite.T(), "UpdateExtraAttrs", mock.Anything, int64(JobId), extraAttrsMatcher) + _, err = os.Stat("/tmp/scandata_export_100.csv") + suite.Truef(os.IsNotExist(err), "Expected CSV file to be deleted") + +} + +func (suite *ScanDataExportJobTestSuite) TestRunWithCriteria() { + { + data := suite.createDataRecords(3, 1) + + mock.OnAnything(suite.exportMgr, "Fetch").Return(data, nil).Once() + mock.OnAnything(suite.exportMgr, "Fetch").Return(make([]export.Data, 0), nil).Once() + mock.OnAnything(suite.digestCalculator, "Calculate").Return(digest.Digest(MockDigest), nil) + execAttrs := make(map[string]interface{}) + execAttrs[export.JobNameAttribute] = "test-job" + execAttrs[export.UserNameAttribute] = "test-user" + mock.OnAnything(suite.execMgr, "Get").Return(&task.Execution{ID: int64(JobId), ExtraAttrs: execAttrs}, nil).Once() + + repoCandidate1 := &selector.Candidate{NamespaceID: 1} + repoCandidates := []*selector.Candidate{repoCandidate1} + mock.OnAnything(suite.filterProcessor, "ProcessProjectFilter").Return([]int64{1}, nil).Once() + mock.OnAnything(suite.filterProcessor, 
"ProcessRepositoryFilter").Return(repoCandidates, nil) + mock.OnAnything(suite.filterProcessor, "ProcessTagFilter").Return(repoCandidates, nil) + + criteria := export.Request{ + CVEIds: "CVE-123", + Labels: []int64{1}, + Projects: []int64{1}, + Repositories: "test-repo", + Tags: "test-tag", + } + criteriaMap := make(map[string]interface{}) + bytes, _ := json.Marshal(criteria) + json.Unmarshal(bytes, &criteriaMap) + params := job.Parameters{} + params[export.JobModeKey] = export.JobModeExport + params["JobId"] = JobId + params["Request"] = criteriaMap + + ctx := &mockjobservice.MockJobContext{} + ctx.On("SystemContext").Return(context.TODO()).Once() + + err := suite.job.Run(ctx, params) + suite.NoError(err) + sysArtifactRecordMatcher := testifymock.MatchedBy(func(sa *model.SystemArtifact) bool { + return sa.Repository == "scandata_export_100" && sa.Vendor == strings.ToLower(export.Vendor) && sa.Digest == MockDigest + }) + suite.sysArtifactMgr.AssertCalled(suite.T(), "Create", mock.Anything, sysArtifactRecordMatcher, mock.Anything) + + m := make(map[string]interface{}) + m[export.DigestKey] = MockDigest + m[export.CreateTimestampKey] = mock.Anything + + extraAttrsMatcher := testifymock.MatchedBy(func(attrsMap map[string]interface{}) bool { + _, ok := m[export.CreateTimestampKey] + return attrsMap[export.DigestKey] == MockDigest && ok + }) + suite.execMgr.AssertCalled(suite.T(), "UpdateExtraAttrs", mock.Anything, int64(JobId), extraAttrsMatcher) + _, err = os.Stat("/tmp/scandata_export_100.csv") + + exportParamsMatcher := testifymock.MatchedBy(func(params export.Params) bool { + return reflect.DeepEqual(params.Labels, criteria.Labels) && reflect.DeepEqual(params.CVEIds, criteria.CVEIds) && reflect.DeepEqual(params.Repositories, []int64{1}) && reflect.DeepEqual(params.Projects, criteria.Projects) + }) + suite.exportMgr.AssertCalled(suite.T(), "Fetch", mock.Anything, exportParamsMatcher) + + suite.Truef(os.IsNotExist(err), "Expected CSV file to be deleted") + } + + { + mock.OnAnything(suite.sysArtifactMgr, "Create").Return(int64(1), nil).Once() + data := suite.createDataRecords(3, 1) + mock.OnAnything(suite.exportMgr, "Fetch").Return(data, nil).Once() + mock.OnAnything(suite.exportMgr, "Fetch").Return(make([]export.Data, 0), nil).Once() + mock.OnAnything(suite.digestCalculator, "Calculate").Return(digest.Digest(MockDigest), nil) + mock.OnAnything(suite.exportMgr, "Fetch").Return(data, nil).Once() + mock.OnAnything(suite.exportMgr, "Fetch").Return(make([]export.Data, 0), nil).Once() + mock.OnAnything(suite.digestCalculator, "Calculate").Return(digest.Digest(MockDigest), nil) + execAttrs := make(map[string]interface{}) + execAttrs[export.JobNameAttribute] = "test-job" + execAttrs[export.UserNameAttribute] = "test-user" + mock.OnAnything(suite.execMgr, "Get").Return(&task.Execution{ID: int64(JobId), ExtraAttrs: execAttrs}, nil).Once() + + repoCandidate1 := &selector.Candidate{NamespaceID: 1} + repoCandidates := []*selector.Candidate{repoCandidate1} + mock.OnAnything(suite.filterProcessor, "ProcessProjectFilter").Return([]int64{1}, nil).Once() + mock.OnAnything(suite.filterProcessor, "ProcessRepositoryFilter").Return(repoCandidates, nil) + mock.OnAnything(suite.filterProcessor, "ProcessTagFilter").Return(repoCandidates, nil) + + criteria := export.Request{ + CVEIds: "CVE-123", + Labels: []int64{1}, + Projects: []int64{1}, + Tags: "test-tag", + } + criteriaMap := make(map[string]interface{}) + bytes, _ := json.Marshal(criteria) + json.Unmarshal(bytes, &criteriaMap) + params := job.Parameters{} + 
params[export.JobModeKey] = export.JobModeExport + params["JobId"] = JobId + params["Request"] = criteriaMap + + ctx := &mockjobservice.MockJobContext{} + ctx.On("SystemContext").Return(context.TODO()).Times(3) + + err := suite.job.Run(ctx, params) + suite.NoError(err) + sysArtifactRecordMatcher := testifymock.MatchedBy(func(sa *model.SystemArtifact) bool { + return sa.Repository == "scandata_export_100" && sa.Vendor == strings.ToLower(export.Vendor) && sa.Digest == MockDigest + }) + suite.sysArtifactMgr.AssertCalled(suite.T(), "Create", mock.Anything, sysArtifactRecordMatcher, mock.Anything) + m := make(map[string]interface{}) + m[export.DigestKey] = MockDigest + m[export.CreateTimestampKey] = mock.Anything + + extraAttrsMatcher := testifymock.MatchedBy(func(attrsMap map[string]interface{}) bool { + _, ok := m[export.CreateTimestampKey] + return attrsMap[export.DigestKey] == MockDigest && ok + }) + suite.execMgr.AssertCalled(suite.T(), "UpdateExtraAttrs", mock.Anything, int64(JobId), extraAttrsMatcher) + _, err = os.Stat("/tmp/scandata_export_100.csv") + + exportParamsMatcher := testifymock.MatchedBy(func(params export.Params) bool { + return reflect.DeepEqual(params.Labels, criteria.Labels) && reflect.DeepEqual(params.CVEIds, criteria.CVEIds) && reflect.DeepEqual(params.Repositories, []int64{1}) && reflect.DeepEqual(params.Projects, criteria.Projects) + }) + suite.exportMgr.AssertCalled(suite.T(), "Fetch", mock.Anything, exportParamsMatcher) + + suite.Truef(os.IsNotExist(err), "Expected CSV file to be deleted") + } +} + +func (suite *ScanDataExportJobTestSuite) TestRunWithCriteriaForProjectIdFilter() { + { + data := suite.createDataRecords(3, 1) + + mock.OnAnything(suite.exportMgr, "Fetch").Return(data, nil).Once() + mock.OnAnything(suite.exportMgr, "Fetch").Return(make([]export.Data, 0), nil).Once() + mock.OnAnything(suite.digestCalculator, "Calculate").Return(digest.Digest(MockDigest), nil) + mock.OnAnything(suite.exportMgr, "Fetch").Return(data, nil).Once() + mock.OnAnything(suite.exportMgr, "Fetch").Return(make([]export.Data, 0), nil).Once() + mock.OnAnything(suite.digestCalculator, "Calculate").Return(digest.Digest(MockDigest), nil) + execAttrs := make(map[string]interface{}) + execAttrs[export.JobNameAttribute] = "test-job" + execAttrs[export.UserNameAttribute] = "test-user" + mock.OnAnything(suite.execMgr, "Get").Return(&task.Execution{ID: int64(JobId), ExtraAttrs: execAttrs}, nil).Once() + + repoCandidate1 := &selector.Candidate{NamespaceID: 1} + repoCandidates := []*selector.Candidate{repoCandidate1} + mock.OnAnything(suite.filterProcessor, "ProcessProjectFilter").Return(nil, errors.New("test error")).Once() + mock.OnAnything(suite.filterProcessor, "ProcessRepositoryFilter").Return(repoCandidates, nil) + mock.OnAnything(suite.filterProcessor, "ProcessTagFilter").Return(repoCandidates, nil) + + criteria := export.Request{ + CVEIds: "CVE-123", + Labels: []int64{1}, + Projects: []int64{1}, + Repositories: "test-repo", + Tags: "test-tag", + } + criteriaMap := make(map[string]interface{}) + bytes, _ := json.Marshal(criteria) + json.Unmarshal(bytes, &criteriaMap) + params := job.Parameters{} + params[export.JobModeKey] = export.JobModeExport + params["JobId"] = JobId + params["Request"] = criteriaMap + + ctx := &mockjobservice.MockJobContext{} + ctx.On("SystemContext").Return(context.TODO()).Once() + + err := suite.job.Run(ctx, params) + suite.Error(err) + sysArtifactRecordMatcher := testifymock.MatchedBy(func(sa *model.SystemArtifact) bool { + return sa.Repository == 
"scandata_export_100" && sa.Vendor == strings.ToLower(export.Vendor) && sa.Digest == MockDigest + }) + suite.sysArtifactMgr.AssertNotCalled(suite.T(), "Create", mock.Anything, sysArtifactRecordMatcher, mock.Anything) + suite.execMgr.AssertNotCalled(suite.T(), "UpdateExtraAttrs", mock.Anything, int64(JobId), mock.Anything) + _, err = os.Stat("/tmp/scandata_export_100.csv") + + suite.exportMgr.AssertNotCalled(suite.T(), "Fetch", mock.Anything, mock.Anything) + + suite.Truef(os.IsNotExist(err), "Expected CSV file to be deleted") + } + + // empty list of projects + { + data := suite.createDataRecords(3, 1) + + mock.OnAnything(suite.exportMgr, "Fetch").Return(data, nil).Once() + mock.OnAnything(suite.exportMgr, "Fetch").Return(make([]export.Data, 0), nil).Once() + mock.OnAnything(suite.digestCalculator, "Calculate").Return(digest.Digest(MockDigest), nil) + mock.OnAnything(suite.exportMgr, "Fetch").Return(data, nil).Once() + mock.OnAnything(suite.exportMgr, "Fetch").Return(make([]export.Data, 0), nil).Once() + mock.OnAnything(suite.digestCalculator, "Calculate").Return(digest.Digest(MockDigest), nil) + execAttrs := make(map[string]interface{}) + execAttrs[export.JobNameAttribute] = "test-job" + execAttrs[export.UserNameAttribute] = "test-user" + mock.OnAnything(suite.execMgr, "Get").Return(&task.Execution{ID: int64(JobId), ExtraAttrs: execAttrs}, nil).Once() + + repoCandidate1 := &selector.Candidate{NamespaceID: 1} + repoCandidates := []*selector.Candidate{repoCandidate1} + mock.OnAnything(suite.filterProcessor, "ProcessProjectFilter").Return([]int64{}, nil).Once() + mock.OnAnything(suite.filterProcessor, "ProcessRepositoryFilter").Return(repoCandidates, nil) + mock.OnAnything(suite.filterProcessor, "ProcessTagFilter").Return(repoCandidates, nil) + + criteria := export.Request{ + CVEIds: "CVE-123", + Labels: []int64{1}, + Projects: []int64{1}, + Repositories: "test-repo", + Tags: "test-tag", + } + criteriaMap := make(map[string]interface{}) + bytes, _ := json.Marshal(criteria) + json.Unmarshal(bytes, &criteriaMap) + params := job.Parameters{} + params[export.JobModeKey] = export.JobModeExport + params["JobId"] = JobId + params["Request"] = criteriaMap + + ctx := &mockjobservice.MockJobContext{} + ctx.On("SystemContext").Return(context.TODO()).Once() + + err := suite.job.Run(ctx, params) + suite.NoError(err) + sysArtifactRecordMatcher := testifymock.MatchedBy(func(sa *model.SystemArtifact) bool { + return sa.Repository == "scandata_export_100" && sa.Vendor == strings.ToLower(export.Vendor) && sa.Digest == MockDigest + }) + suite.sysArtifactMgr.AssertCalled(suite.T(), "Create", mock.Anything, sysArtifactRecordMatcher, mock.Anything) + + suite.execMgr.AssertCalled(suite.T(), "UpdateExtraAttrs", mock.Anything, int64(JobId), mock.Anything) + _, err = os.Stat("/tmp/scandata_export_100.csv") + + suite.exportMgr.AssertNotCalled(suite.T(), "Fetch", mock.Anything, mock.Anything) + + suite.Truef(os.IsNotExist(err), "Expected CSV file to be deleted") + } + +} + +func (suite *ScanDataExportJobTestSuite) TestRunWithCriteriaForRepositoryIdFilter() { + { + data := suite.createDataRecords(3, 1) + + mock.OnAnything(suite.exportMgr, "Fetch").Return(data, nil).Once() + mock.OnAnything(suite.exportMgr, "Fetch").Return(make([]export.Data, 0), nil).Once() + mock.OnAnything(suite.digestCalculator, "Calculate").Return(digest.Digest(MockDigest), nil) + mock.OnAnything(suite.exportMgr, "Fetch").Return(data, nil).Once() + mock.OnAnything(suite.exportMgr, "Fetch").Return(make([]export.Data, 0), nil).Once() + 
mock.OnAnything(suite.digestCalculator, "Calculate").Return(digest.Digest(MockDigest), nil) + execAttrs := make(map[string]interface{}) + execAttrs[export.JobNameAttribute] = "test-job" + execAttrs[export.UserNameAttribute] = "test-user" + mock.OnAnything(suite.execMgr, "Get").Return(&task.Execution{ID: int64(JobId), ExtraAttrs: execAttrs}, nil).Once() + + repoCandidate1 := &selector.Candidate{NamespaceID: 1} + repoCandidates := []*selector.Candidate{repoCandidate1} + mock.OnAnything(suite.filterProcessor, "ProcessProjectFilter").Return([]int64{1}, errors.New("test error")).Once() + mock.OnAnything(suite.filterProcessor, "ProcessRepositoryFilter").Return(nil, errors.New("test error")) + mock.OnAnything(suite.filterProcessor, "ProcessTagFilter").Return(repoCandidates, nil) + + criteria := export.Request{ + CVEIds: "CVE-123", + Labels: []int64{1}, + Projects: []int64{1}, + Repositories: "test-repo", + Tags: "test-tag", + } + criteriaMap := make(map[string]interface{}) + bytes, _ := json.Marshal(criteria) + json.Unmarshal(bytes, &criteriaMap) + params := job.Parameters{} + params[export.JobModeKey] = export.JobModeExport + params["JobId"] = JobId + params["Request"] = criteriaMap + + ctx := &mockjobservice.MockJobContext{} + ctx.On("SystemContext").Return(context.TODO()).Once() + + err := suite.job.Run(ctx, params) + suite.Error(err) + sysArtifactRecordMatcher := testifymock.MatchedBy(func(sa *model.SystemArtifact) bool { + return sa.Repository == "scandata_export_100" && sa.Vendor == strings.ToLower(export.Vendor) && sa.Digest == MockDigest + }) + suite.sysArtifactMgr.AssertNotCalled(suite.T(), "Create", mock.Anything, sysArtifactRecordMatcher, mock.Anything) + suite.execMgr.AssertNotCalled(suite.T(), "UpdateExtraAttrs", mock.Anything, int64(JobId), mock.Anything) + _, err = os.Stat("/tmp/scandata_export_100.csv") + + suite.exportMgr.AssertNotCalled(suite.T(), "Fetch", mock.Anything, mock.Anything) + + suite.Truef(os.IsNotExist(err), "Expected CSV file to be deleted") + } + + // empty list of repo ids + { + data := suite.createDataRecords(3, 1) + + mock.OnAnything(suite.exportMgr, "Fetch").Return(data, nil).Once() + mock.OnAnything(suite.exportMgr, "Fetch").Return(make([]export.Data, 0), nil).Once() + mock.OnAnything(suite.digestCalculator, "Calculate").Return(digest.Digest(MockDigest), nil) + mock.OnAnything(suite.exportMgr, "Fetch").Return(data, nil).Once() + mock.OnAnything(suite.exportMgr, "Fetch").Return(make([]export.Data, 0), nil).Once() + mock.OnAnything(suite.digestCalculator, "Calculate").Return(digest.Digest(MockDigest), nil) + execAttrs := make(map[string]interface{}) + execAttrs[export.JobNameAttribute] = "test-job" + execAttrs[export.UserNameAttribute] = "test-user" + mock.OnAnything(suite.execMgr, "Get").Return(&task.Execution{ID: int64(JobId), ExtraAttrs: execAttrs}, nil).Once() + + repoCandidates := make([]*selector.Candidate, 0) + mock.OnAnything(suite.filterProcessor, "ProcessProjectFilter").Return([]int64{}, nil).Once() + mock.OnAnything(suite.filterProcessor, "ProcessRepositoryFilter").Return(repoCandidates, nil) + mock.OnAnything(suite.filterProcessor, "ProcessTagFilter").Return(repoCandidates, nil) + + criteria := export.Request{ + CVEIds: "CVE-123", + Labels: []int64{1}, + Projects: []int64{1}, + Repositories: "test-repo", + Tags: "test-tag", + } + criteriaMap := make(map[string]interface{}) + bytes, _ := json.Marshal(criteria) + json.Unmarshal(bytes, &criteriaMap) + params := job.Parameters{} + params[export.JobModeKey] = export.JobModeExport + params["JobId"] = 
JobId + params["Request"] = criteriaMap + + ctx := &mockjobservice.MockJobContext{} + ctx.On("SystemContext").Return(context.TODO()).Once() + + err := suite.job.Run(ctx, params) + suite.NoError(err) + sysArtifactRecordMatcher := testifymock.MatchedBy(func(sa *model.SystemArtifact) bool { + return sa.Repository == "scandata_export_100" && sa.Vendor == strings.ToLower(export.Vendor) && sa.Digest == MockDigest + }) + suite.sysArtifactMgr.AssertCalled(suite.T(), "Create", mock.Anything, sysArtifactRecordMatcher, mock.Anything) + suite.execMgr.AssertCalled(suite.T(), "UpdateExtraAttrs", mock.Anything, int64(JobId), mock.Anything) + _, err = os.Stat("/tmp/scandata_export_100.csv") + + suite.exportMgr.AssertNotCalled(suite.T(), "Fetch", mock.Anything, mock.Anything) + + suite.Truef(os.IsNotExist(err), "Expected CSV file to be deleted") + } + +} + +func (suite *ScanDataExportJobTestSuite) TestRunWithCriteriaForRepositoryIdWithTagFilter() { + { + data := suite.createDataRecords(3, 1) + + mock.OnAnything(suite.exportMgr, "Fetch").Return(data, nil).Once() + mock.OnAnything(suite.exportMgr, "Fetch").Return(make([]export.Data, 0), nil).Once() + mock.OnAnything(suite.digestCalculator, "Calculate").Return(digest.Digest(MockDigest), nil) + mock.OnAnything(suite.exportMgr, "Fetch").Return(data, nil).Once() + mock.OnAnything(suite.exportMgr, "Fetch").Return(make([]export.Data, 0), nil).Once() + mock.OnAnything(suite.digestCalculator, "Calculate").Return(digest.Digest(MockDigest), nil) + execAttrs := make(map[string]interface{}) + execAttrs[export.JobNameAttribute] = "test-job" + execAttrs[export.UserNameAttribute] = "test-user" + mock.OnAnything(suite.execMgr, "Get").Return(&task.Execution{ID: int64(JobId), ExtraAttrs: execAttrs}, nil).Once() + + repoCandidate1 := &selector.Candidate{NamespaceID: 1} + repoCandidates := []*selector.Candidate{repoCandidate1} + mock.OnAnything(suite.filterProcessor, "ProcessProjectFilter").Return([]int64{1}, errors.New("test error")).Once() + mock.OnAnything(suite.filterProcessor, "ProcessRepositoryFilter").Return(repoCandidates, nil) + mock.OnAnything(suite.filterProcessor, "ProcessTagFilter").Return(nil, errors.New("test error")) + + criteria := export.Request{ + CVEIds: "CVE-123", + Labels: []int64{1}, + Projects: []int64{1}, + Repositories: "test-repo", + Tags: "test-tag", + } + criteriaMap := make(map[string]interface{}) + bytes, _ := json.Marshal(criteria) + json.Unmarshal(bytes, &criteriaMap) + params := job.Parameters{} + params[export.JobModeKey] = export.JobModeExport + params["JobId"] = JobId + params["Request"] = criteriaMap + + ctx := &mockjobservice.MockJobContext{} + ctx.On("SystemContext").Return(context.TODO()).Once() + + err := suite.job.Run(ctx, params) + suite.Error(err) + sysArtifactRecordMatcher := testifymock.MatchedBy(func(sa *model.SystemArtifact) bool { + return sa.Repository == "scandata_export_100" && sa.Vendor == strings.ToLower(export.Vendor) && sa.Digest == MockDigest + }) + suite.sysArtifactMgr.AssertNotCalled(suite.T(), "Create", mock.Anything, sysArtifactRecordMatcher, mock.Anything) + suite.execMgr.AssertNotCalled(suite.T(), "UpdateExtraAttrs", mock.Anything, int64(JobId), mock.Anything) + _, err = os.Stat("/tmp/scandata_export_100.csv") + + suite.exportMgr.AssertNotCalled(suite.T(), "Fetch", mock.Anything, mock.Anything) + + suite.Truef(os.IsNotExist(err), "Expected CSV file to be deleted") + } + + // empty list of repo ids after applying tag filters + { + data := suite.createDataRecords(3, 1) + + mock.OnAnything(suite.exportMgr, 
"Fetch").Return(data, nil).Once() + mock.OnAnything(suite.exportMgr, "Fetch").Return(make([]export.Data, 0), nil).Once() + mock.OnAnything(suite.digestCalculator, "Calculate").Return(digest.Digest(MockDigest), nil) + mock.OnAnything(suite.exportMgr, "Fetch").Return(data, nil).Once() + mock.OnAnything(suite.exportMgr, "Fetch").Return(make([]export.Data, 0), nil).Once() + mock.OnAnything(suite.digestCalculator, "Calculate").Return(digest.Digest(MockDigest), nil) + execAttrs := make(map[string]interface{}) + execAttrs[export.JobNameAttribute] = "test-job" + execAttrs[export.UserNameAttribute] = "test-user" + mock.OnAnything(suite.execMgr, "Get").Return(&task.Execution{ID: int64(JobId), ExtraAttrs: execAttrs}, nil).Once() + + repoCandidates := make([]*selector.Candidate, 0) + mock.OnAnything(suite.filterProcessor, "ProcessProjectFilter").Return([]int64{}, nil).Once() + mock.OnAnything(suite.filterProcessor, "ProcessRepositoryFilter").Return(repoCandidates, nil) + mock.OnAnything(suite.filterProcessor, "ProcessTagFilter").Return(make([]*selector.Candidate, 0), nil) + + criteria := export.Request{ + CVEIds: "CVE-123", + Labels: []int64{1}, + Projects: []int64{1}, + Repositories: "test-repo", + Tags: "test-tag", + } + criteriaMap := make(map[string]interface{}) + bytes, _ := json.Marshal(criteria) + json.Unmarshal(bytes, &criteriaMap) + params := job.Parameters{} + params[export.JobModeKey] = export.JobModeExport + params["JobId"] = JobId + params["Request"] = criteriaMap + + ctx := &mockjobservice.MockJobContext{} + ctx.On("SystemContext").Return(context.TODO()).Once() + + err := suite.job.Run(ctx, params) + suite.NoError(err) + sysArtifactRecordMatcher := testifymock.MatchedBy(func(sa *model.SystemArtifact) bool { + return sa.Repository == "scandata_export_100" && sa.Vendor == strings.ToLower(export.Vendor) && sa.Digest == MockDigest + }) + suite.sysArtifactMgr.AssertCalled(suite.T(), "Create", mock.Anything, sysArtifactRecordMatcher, mock.Anything) + suite.execMgr.AssertCalled(suite.T(), "UpdateExtraAttrs", mock.Anything, int64(JobId), mock.Anything) + _, err = os.Stat("/tmp/scandata_export_100.csv") + + suite.exportMgr.AssertNotCalled(suite.T(), "Fetch", mock.Anything, mock.Anything) + + suite.Truef(os.IsNotExist(err), "Expected CSV file to be deleted") + } + +} + +func (suite *ScanDataExportJobTestSuite) TestExportDigestCalculationErrorsOut() { + data := suite.createDataRecords(3, 1) + mock.OnAnything(suite.exportMgr, "Fetch").Return(data, nil).Once() + mock.OnAnything(suite.exportMgr, "Fetch").Return(make([]export.Data, 0), nil).Once() + mock.OnAnything(suite.digestCalculator, "Calculate").Return(digest.Digest(""), errors.New("test error")) + params := job.Parameters{} + params[export.JobModeKey] = export.JobModeExport + params["JobId"] = JobId + ctx := &mockjobservice.MockJobContext{} + + err := suite.job.Run(ctx, params) + suite.Error(err) + sysArtifactRecordMatcher := testifymock.MatchedBy(func(sa *model.SystemArtifact) bool { + return sa.Repository == "scandata_export_100" && sa.Vendor == strings.ToLower(export.Vendor) && sa.Digest == MockDigest + }) + suite.sysArtifactMgr.AssertNotCalled(suite.T(), "Create", mock.Anything, sysArtifactRecordMatcher, mock.Anything) + suite.execMgr.AssertNotCalled(suite.T(), "UpdateExtraAttrs") + _, err = os.Stat("/tmp/scandata_export_100.csv") + suite.Truef(os.IsNotExist(err), "Expected CSV file to be deleted") +} + +func (suite *ScanDataExportJobTestSuite) TearDownTest() { + path := fmt.Sprintf("/tmp/scandata_export_%v.csv", JobId) + _, err := 
os.Stat(path) + if os.IsNotExist(err) { + return + } + err = os.Remove(path) + suite.NoError(err) +} + +func (suite *ScanDataExportJobTestSuite) createDataRecords(numRecs int, ownerId int64) []export.Data { + data := make([]export.Data, 0) + for i := 1; i <= numRecs; i++ { + dataRec := export.Data{ + ID: int64(i), + ProjectName: fmt.Sprintf("TestProject%d", i), + ProjectOwner: strconv.FormatInt(ownerId, 10), + ScannerName: fmt.Sprintf("TestScanner%d", i), + CVEId: fmt.Sprintf("CVEId-%d", i), + Package: fmt.Sprintf("Package%d", i), + Severity: fmt.Sprintf("Severity%d", i), + CVSSScoreV3: fmt.Sprintf("3.0"), + CVSSScoreV2: fmt.Sprintf("2.0"), + CVSSVectorV3: fmt.Sprintf("TestCVSSVectorV3%d", i), + CVSSVectorV2: fmt.Sprintf("TestCVSSVectorV2%d", i), + CWEIds: "", + } + data = append(data, dataRec) + } + return data +} +func TestScanDataExportJobSuite(t *testing.T) { + suite.Run(t, &ScanDataExportJobTestSuite{}) +} diff --git a/src/jobservice/job/known_jobs.go b/src/jobservice/job/known_jobs.go index 69710bef9..c1542c667 100644 --- a/src/jobservice/job/known_jobs.go +++ b/src/jobservice/job/known_jobs.go @@ -38,4 +38,6 @@ const ( PurgeAudit = "PURGE_AUDIT" // SystemArtifactCleanup : the name of the SystemArtifact cleanup job SystemArtifactCleanup = "SYSTEM_ARTIFACT_CLEANUP" + // ScanDataExport : the name of the scan data export job + ScanDataExport = "SCAN_DATA_EXPORT" ) diff --git a/src/jobservice/runtime/bootstrap.go b/src/jobservice/runtime/bootstrap.go index f0fe34c01..75b403cef 100644 --- a/src/jobservice/runtime/bootstrap.go +++ b/src/jobservice/runtime/bootstrap.go @@ -17,6 +17,7 @@ package runtime import ( "context" "fmt" + "github.com/goharbor/harbor/src/jobservice/job/impl/scandataexport" "os" "os/signal" "strings" @@ -319,6 +320,7 @@ func (bs *Bootstrap) loadAndRunRedisWorkerPool( job.WebhookJob: (*notification.WebhookJob)(nil), job.SlackJob: (*notification.SlackJob)(nil), job.P2PPreheat: (*preheat.Job)(nil), + job.ScanDataExport: (*scandataexport.ScanDataExport)(nil), // In v2.2 we migrate the scheduled replication, garbage collection and scan all to // the scheduler mechanism, the following three jobs are kept for the legacy jobs // and they can be removed after several releases diff --git a/src/pkg/scan/export/constants.go b/src/pkg/scan/export/constants.go new file mode 100644 index 000000000..b1bdd7d3a --- /dev/null +++ b/src/pkg/scan/export/constants.go @@ -0,0 +1,15 @@ +package export + +// CsvJobVendorID specific type to be used in contexts +type CsvJobVendorID string + +const ( + JobNameAttribute = "job_name" + UserNameAttribute = "user_name" + ScanDataExportDir = "/var/scandata_exports" + QueryPageSize = 100 + DigestKey = "artifact_digest" + CreateTimestampKey = "create_ts" + Vendor = "SCAN_DATA_EXPORT" + CsvJobVendorIDKey = CsvJobVendorID("vendorId") +) diff --git a/src/pkg/scan/export/digest_calculator.go b/src/pkg/scan/export/digest_calculator.go new file mode 100644 index 000000000..1469b9178 --- /dev/null +++ b/src/pkg/scan/export/digest_calculator.go @@ -0,0 +1,28 @@ +package export + +import ( + "crypto/sha256" + "github.com/opencontainers/go-digest" + "io" + "os" +) + +// ArtifactDigestCalculator is an interface to be implemented by all file hash calculators +type ArtifactDigestCalculator interface { + // Calculate returns the hash for a file + Calculate(fileName string) (digest.Digest, error) +} + +type SHA256ArtifactDigestCalculator struct{} + +func (calc *SHA256ArtifactDigestCalculator) Calculate(fileName string) (digest.Digest, error) { + file, err := 
os.Open(fileName) + if err != nil { + return "", err + } + defer file.Close() + hash := sha256.New() + if _, err := io.Copy(hash, file); err != nil { + return "", err + } + return digest.NewDigest(digest.SHA256, hash), nil +} diff --git a/src/pkg/scan/export/digest_calculator_test.go b/src/pkg/scan/export/digest_calculator_test.go new file mode 100644 index 000000000..36da1788d --- /dev/null +++ b/src/pkg/scan/export/digest_calculator_test.go @@ -0,0 +1,32 @@ +package export + +import ( + "crypto/sha256" + "fmt" + "github.com/stretchr/testify/suite" + "io/ioutil" + "os" + "testing" +) + +type DigestCalculatorTestSuite struct { + suite.Suite +} + +func (suite *DigestCalculatorTestSuite) TestDigestCalculation() { + fileName := "/tmp/testfile.txt" + data := []byte("test") + err := ioutil.WriteFile(fileName, data, os.ModePerm) + suite.NoError(err) + digestCalc := SHA256ArtifactDigestCalculator{} + digest, err := digestCalc.Calculate(fileName) + suite.NoError(err) + + hasher := sha256.New() + hasher.Write(data) + expectedDigest := fmt.Sprintf("sha256:%x", hasher.Sum(nil)) + suite.Equal(expectedDigest, digest.String()) +} + +func TestDigestCalculatorTestSuite(t *testing.T) { + suite.Run(t, &DigestCalculatorTestSuite{}) +} diff --git a/src/pkg/scan/export/export_data_selector.go b/src/pkg/scan/export/export_data_selector.go new file mode 100644 index 000000000..e96c483f6 --- /dev/null +++ b/src/pkg/scan/export/export_data_selector.go @@ -0,0 +1,60 @@ +package export + +import "github.com/bmatcuk/doublestar" + +const ( + CVEIDMatches = "cveIdMatches" + PackageMatches = "packageMatches" + ScannerMatches = "scannerMatches" + CVE2VectorMatches = "cve2VectorMatches" + CVE3VectorMatches = "cve3VectorMatches" +) + +// VulnerabilityDataSelector is a specialized implementation of a selector +// leveraging the doublestar pattern to select vulnerabilities +type VulnerabilityDataSelector interface { + Select(vulnDataRecords []Data, decoration string, pattern string) ([]Data, error) +} + +type defaultVulnerabilitySelector struct{} + +// NewVulnerabilityDataSelector returns a selector that matches vulnerability +// data records against the provided conditions +func NewVulnerabilityDataSelector() VulnerabilityDataSelector { + return &defaultVulnerabilitySelector{} +} + +func (vds *defaultVulnerabilitySelector) Select(vulnDataRecords []Data, decoration string, pattern string) ([]Data, error) { + selected := make([]Data, 0) + value := "" + + for _, vulnDataRecord := range vulnDataRecords { + switch decoration { + case CVEIDMatches: + value = vulnDataRecord.CVEId + case PackageMatches: + value = vulnDataRecord.Package + case ScannerMatches: + value = vulnDataRecord.ScannerName + case CVE2VectorMatches: + value = vulnDataRecord.CVSSVectorV2 + case CVE3VectorMatches: + value = vulnDataRecord.CVSSVectorV3 + } + matched, err := vds.match(pattern, value) + if err != nil { + return nil, err + } + if matched { + selected = append(selected, vulnDataRecord) + } + } + return selected, nil +} + +func (vds *defaultVulnerabilitySelector) match(pattern, str string) (bool, error) { + if len(pattern) == 0 { + return true, nil + } + return doublestar.Match(pattern, str) +} diff --git a/src/pkg/scan/export/export_data_selector_test.go b/src/pkg/scan/export/export_data_selector_test.go new file mode 100644 index 000000000..a19ea51d6 --- /dev/null +++ b/src/pkg/scan/export/export_data_selector_test.go @@ -0,0 +1,123 @@ +package export + +import ( + "fmt" + "github.com/stretchr/testify/suite" + "strconv" + "testing" +) + +type ExportDataSelectorTestSuite struct { 
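For reference, the selector defined above is exercised as follows. A usage sketch (hypothetical; records is any []export.Data slice) showing the doublestar matching behaviour, where an empty pattern selects every record:

sel := export.NewVulnerabilityDataSelector()
// keep only records whose CVE ID matches the glob; "" would match all
matched, err := sel.Select(records, export.CVEIDMatches, "CVE-2021-*")
if err != nil {
	return err // doublestar reports malformed patterns as errors
}
fmt.Printf("matched %d of %d records\n", len(matched), len(records))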
suite.Suite + exportDataSelector VulnerabilityDataSelector +} + +func (suite *ExportDataSelectorTestSuite) SetupSuite() { + suite.exportDataSelector = NewVulnerabilityDataSelector() +} + +func (suite *ExportDataSelectorTestSuite) TestCVEFilter() { + { + dataRecords := suite.createDataRecords(10, 1) + filtered, err := suite.exportDataSelector.Select(dataRecords, CVEIDMatches, "CVEId-1") + suite.NoError(err) + suite.Equal(1, len(filtered)) + suite.Equal("CVEId-1", filtered[0].CVEId) + } + { + dataRecords := suite.createDataRecords(10, 1) + filtered, err := suite.exportDataSelector.Select(dataRecords, CVEIDMatches, "") + suite.NoError(err) + suite.Equal(10, len(filtered)) + } +} + +func (suite *ExportDataSelectorTestSuite) TestPackageFilter() { + { + dataRecords := suite.createDataRecords(10, 1) + filtered, err := suite.exportDataSelector.Select(dataRecords, PackageMatches, "Package1") + suite.NoError(err) + suite.Equal(1, len(filtered)) + suite.Equal("Package1", filtered[0].Package) + } + { + dataRecords := suite.createDataRecords(10, 1) + filtered, err := suite.exportDataSelector.Select(dataRecords, PackageMatches, "") + suite.NoError(err) + suite.Equal(10, len(filtered)) + } +} + +func (suite *ExportDataSelectorTestSuite) TestScannerNameFilter() { + { + dataRecords := suite.createDataRecords(10, 1) + filtered, err := suite.exportDataSelector.Select(dataRecords, ScannerMatches, "TestScanner1") + suite.NoError(err) + suite.Equal(1, len(filtered)) + suite.Equal("TestScanner1", filtered[0].ScannerName) + } + { + dataRecords := suite.createDataRecords(10, 1) + filtered, err := suite.exportDataSelector.Select(dataRecords, ScannerMatches, "") + suite.NoError(err) + suite.Equal(10, len(filtered)) + } +} + +func (suite *ExportDataSelectorTestSuite) TestCVE2VectorMatches() { + { + dataRecords := suite.createDataRecords(10, 1) + filtered, err := suite.exportDataSelector.Select(dataRecords, CVE2VectorMatches, "TestCVSSVectorV21") + suite.NoError(err) + suite.Equal(1, len(filtered)) + suite.Equal("TestCVSSVectorV21", filtered[0].CVSSVectorV2) + } + { + dataRecords := suite.createDataRecords(10, 1) + filtered, err := suite.exportDataSelector.Select(dataRecords, CVE2VectorMatches, "") + suite.NoError(err) + suite.Equal(10, len(filtered)) + } +} + +func (suite *ExportDataSelectorTestSuite) TestCVE3VectorMatches() { + { + dataRecords := suite.createDataRecords(10, 1) + filtered, err := suite.exportDataSelector.Select(dataRecords, CVE3VectorMatches, "TestCVSSVectorV31") + suite.NoError(err) + suite.Equal(1, len(filtered)) + suite.Equal("TestCVSSVectorV31", filtered[0].CVSSVectorV3) + } + { + dataRecords := suite.createDataRecords(10, 1) + filtered, err := suite.exportDataSelector.Select(dataRecords, CVE3VectorMatches, "") + suite.NoError(err) + suite.Equal(10, len(filtered)) + } +} + +func TestExportDataSelectorTestSuite(t *testing.T) { + suite.Run(t, &ExportDataSelectorTestSuite{}) +} + +func (suite *ExportDataSelectorTestSuite) createDataRecords(numRecs int, ownerId int64) []Data { + data := make([]Data, 0) + for i := 1; i <= numRecs; i++ { + dataRec := Data{ + ID: int64(i), + ProjectName: fmt.Sprintf("TestProject%d", i), + ProjectOwner: strconv.FormatInt(ownerId, 10), + ScannerName: fmt.Sprintf("TestScanner%d", i), + CVEId: fmt.Sprintf("CVEId-%d", i), + Package: fmt.Sprintf("Package%d", i), + Severity: fmt.Sprintf("Severity%d", i), + CVSSScoreV3: fmt.Sprintf("3.0"), + CVSSScoreV2: fmt.Sprintf("2.0"), + CVSSVectorV3: fmt.Sprintf("TestCVSSVectorV3%d", i), + CVSSVectorV2: fmt.Sprintf("TestCVSSVectorV2%d", 
i), + CWEIds: "", + } + data = append(data, dataRec) + } + return data +} diff --git a/src/pkg/scan/export/filter_processor.go b/src/pkg/scan/export/filter_processor.go new file mode 100644 index 000000000..6bb33879b --- /dev/null +++ b/src/pkg/scan/export/filter_processor.go @@ -0,0 +1,142 @@ +package export + +import ( + "context" + "github.com/goharbor/harbor/src/pkg" + + "github.com/goharbor/harbor/src/common" + commonmodels "github.com/goharbor/harbor/src/common/models" + "github.com/goharbor/harbor/src/common/security/local" + "github.com/goharbor/harbor/src/common/utils" + "github.com/goharbor/harbor/src/jobservice/logger" + "github.com/goharbor/harbor/src/lib/q" + "github.com/goharbor/harbor/src/lib/selector" + "github.com/goharbor/harbor/src/lib/selector/selectors/doublestar" + "github.com/goharbor/harbor/src/pkg/project" + "github.com/goharbor/harbor/src/pkg/project/models" + "github.com/goharbor/harbor/src/pkg/repository" + "github.com/goharbor/harbor/src/pkg/tag" + "github.com/goharbor/harbor/src/pkg/user" +) + +type FilterProcessor interface { + ProcessProjectFilter(ctx context.Context, userName string, projectsToFilter []int64) ([]int64, error) + ProcessRepositoryFilter(ctx context.Context, filter string, projectIds []int64) ([]*selector.Candidate, error) + ProcessTagFilter(ctx context.Context, filter string, repositoryIds []int64) ([]*selector.Candidate, error) +} + +type DefaultFilterProcessor struct { + repoMgr repository.Manager + tagMgr tag.Manager + usrMgr user.Manager + projectMgr project.Manager +} + +// NewFilterProcessor constructs an instance of a FilterProcessor +func NewFilterProcessor() FilterProcessor { + return &DefaultFilterProcessor{ + repoMgr: pkg.RepositoryMgr, + tagMgr: tag.Mgr, + usrMgr: user.Mgr, + projectMgr: pkg.ProjectMgr, + } +} + +func (dfp *DefaultFilterProcessor) ProcessProjectFilter(ctx context.Context, userName string, projectIdsToFilter []int64) ([]int64, error) { + // get the user id of the current user + usr, err := dfp.usrMgr.GetByName(ctx, userName) + if err != nil { + return nil, err + } + logger.Infof("Retrieved user id: %d for user name: %s", usr.UserID, userName) + + query := dfp.getProjectQueryFilter(usr) + projects, err := dfp.projectMgr.List(ctx, query) + + if err != nil { + return nil, err + } + logger.Infof("Selected %d projects administered by user %s", len(projects), userName) + projectIds := make([]int64, 0) + for _, proj := range projects { + projectIds = append(projectIds, proj.ProjectID) + } + + // check if the project ids specified in the filter are present in the list + // of projects of which the current user is a project admin + if len(projectIdsToFilter) == 0 { + return projectIds, nil + } + m := make(map[int64]bool) + for _, projectID := range projectIds { + m[projectID] = true + } + filtered := make([]int64, 0) + + for _, filteredProjID := range projectIdsToFilter { + if m[filteredProjID] { + filtered = append(filtered, filteredProjID) + } + } + return filtered, nil +} + +func (dfp *DefaultFilterProcessor) ProcessRepositoryFilter(ctx context.Context, filter string, projectIds []int64) ([]*selector.Candidate, error) { + sel := doublestar.New(doublestar.RepoMatches, filter, "") + candidates := make([]*selector.Candidate, 0) + + for _, projectID := range projectIds { + query := q.New(q.KeyWords{"ProjectID": projectID}) + allRepos, err := dfp.repoMgr.List(ctx, query) + if err != nil { + return nil, err + } + for _, repoRecord := range allRepos { + namespace, repo := 
utils.ParseRepository(repoRecord.Name) + candidates = append(candidates, &selector.Candidate{NamespaceID: repoRecord.RepositoryID, Namespace: namespace, Repository: repo, Kind: "image"}) + } + } + // if no repo filter is specified, return all repos across all projects + if filter == "" { + return candidates, nil + } + return sel.Select(candidates) +} + +func (dfp *DefaultFilterProcessor) ProcessTagFilter(ctx context.Context, filter string, repositoryIds []int64) ([]*selector.Candidate, error) { + sel := doublestar.New(doublestar.Matches, filter, "") + candidates := make([]*selector.Candidate, 0) + + for _, repoID := range repositoryIds { + query := q.New(q.KeyWords{"RepositoryID": repoID}) + allTags, err := dfp.tagMgr.List(ctx, query) + if err != nil { + return nil, err + } + cand := &selector.Candidate{NamespaceID: repoID, Kind: "image"} + for _, t := range allTags { + cand.Tags = append(cand.Tags, t.Name) + } + candidates = append(candidates, cand) + } + // if no tag filter is specified, simply return all the candidates + if filter == "" { + return candidates, nil + } + return sel.Select(candidates) +} + +func (dfp *DefaultFilterProcessor) getProjectQueryFilter(user *commonmodels.User) *q.Query { + secContext := local.NewSecurityContext(user) + if secContext.IsSysAdmin() { + logger.Infof("User %v is sys admin. Selecting all projects for export.", user.Username) + return q.New(q.KeyWords{}) + } + logger.Infof("User %v is not sys admin. Selecting projects with admin roles for export.", user.Username) + return q.New(q.KeyWords{"member": &models.MemberQuery{UserID: user.UserID, Role: common.RoleProjectAdmin}}) +} diff --git a/src/pkg/scan/export/filter_processor_test.go b/src/pkg/scan/export/filter_processor_test.go new file mode 100644 index 000000000..74ea51cef --- /dev/null +++ b/src/pkg/scan/export/filter_processor_test.go @@ -0,0 +1,233 @@ +package export + +import ( + "context" + "errors" + commonmodels "github.com/goharbor/harbor/src/common/models" + "github.com/goharbor/harbor/src/lib/q" + "github.com/goharbor/harbor/src/pkg/project/models" + "github.com/goharbor/harbor/src/pkg/repository/model" + tag2 "github.com/goharbor/harbor/src/pkg/tag/model/tag" + "github.com/goharbor/harbor/src/testing/mock" + "github.com/goharbor/harbor/src/testing/pkg/project" + "github.com/goharbor/harbor/src/testing/pkg/repository" + "github.com/goharbor/harbor/src/testing/pkg/tag" + "github.com/goharbor/harbor/src/testing/pkg/user" + testifymock "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/suite" + "testing" + "time" +) + +type FilterProcessorTestSuite struct { + suite.Suite + repoMgr *repository.Manager + tagMgr *tag.FakeManager + usrMgr *user.Manager + projectMgr *project.Manager + filterProcessor FilterProcessor +} + +func (suite *FilterProcessorTestSuite) SetupSuite() { + +} + +func (suite *FilterProcessorTestSuite) SetupTest() { + suite.repoMgr = &repository.Manager{} + suite.tagMgr = &tag.FakeManager{} + suite.usrMgr = &user.Manager{} + suite.projectMgr = &project.Manager{} + suite.filterProcessor = &DefaultFilterProcessor{ + repoMgr: suite.repoMgr, + tagMgr: suite.tagMgr, + usrMgr: suite.usrMgr, + projectMgr: suite.projectMgr, + } +} + +func (suite *FilterProcessorTestSuite) TestProcessProjectFilter() { + project1 := &models.Project{ProjectID: 1} + + project2 := &models.Project{ProjectID: 2} + + // no filtered projects returns all projects + { + suite.usrMgr.On("GetByName", mock.Anything, "test-user").Return(&commonmodels.User{UserID: 1}, nil).Once() + suite.projectMgr.On("List", 
mock.Anything, mock.Anything).Return([]*models.Project{project1, project2}, nil).Once() + projectIds, err := suite.filterProcessor.ProcessProjectFilter(context.TODO(), "test-user", []int64{}) + suite.Equal(2, len(projectIds)) + suite.NoError(err) + } + + // filtered project + { + suite.usrMgr.On("GetByName", mock.Anything, "test-user").Return(&commonmodels.User{UserID: 1}, nil).Once() + suite.projectMgr.On("List", mock.Anything, mock.Anything).Return([]*models.Project{project1, project2}, nil).Once() + projectIds, err := suite.filterProcessor.ProcessProjectFilter(context.TODO(), "test-user", []int64{1}) + suite.Equal(1, len(projectIds)) + suite.Equal(int64(1), projectIds[0]) + suite.NoError(err) + } + + // project listing for admin user + { + suite.usrMgr.On("GetByName", mock.Anything, "test-user").Return(&commonmodels.User{UserID: 1, SysAdminFlag: true}, nil).Once() + suite.projectMgr.On("List", mock.Anything, mock.Anything).Return([]*models.Project{project1, project2}, nil).Once() + _, err := suite.filterProcessor.ProcessProjectFilter(context.TODO(), "test-user", []int64{1}) + suite.NoError(err) + queryArgumentMatcher := testifymock.MatchedBy(func(query *q.Query) bool { + return len(query.Keywords) == 0 + }) + suite.projectMgr.AssertCalled(suite.T(), "List", mock.Anything, queryArgumentMatcher) + } + + // project listing returns an error + { + suite.usrMgr.On("GetByName", mock.Anything, "test-user").Return(&commonmodels.User{UserID: 1}, nil).Once() + suite.projectMgr.On("List", mock.Anything, mock.Anything).Return(nil, errors.New("test-error")).Once() + projectIds, err := suite.filterProcessor.ProcessProjectFilter(context.TODO(), "test-user", []int64{1}) + suite.Error(err) + suite.Nil(projectIds) + } + +} + +func (suite *FilterProcessorTestSuite) TestProcessRepositoryFilter() { + + repoRecord1 := model.RepoRecord{ + RepositoryID: int64(1), + Name: "test/repo1", + ProjectID: int64(100), + Description: "test repo 1", + PullCount: 1, + StarCount: 4, + CreationTime: time.Time{}, + UpdateTime: time.Time{}, + } + repoRecord2 := model.RepoRecord{ + RepositoryID: int64(1), + Name: "test/repo2", + ProjectID: int64(100), + Description: "test repo 2", + PullCount: 1, + StarCount: 4, + CreationTime: time.Time{}, + UpdateTime: time.Time{}, + } + + allRepos := make([]*model.RepoRecord, 0) + allRepos = append(allRepos, &repoRecord1, &repoRecord2) + + // filter required repositories + { + suite.repoMgr.On("List", mock.Anything, mock.Anything).Return(allRepos, nil).Once() + candidates, err := suite.filterProcessor.ProcessRepositoryFilter(context.TODO(), "repo1", []int64{100}) + suite.NoError(err) + suite.Equal(1, len(candidates), "Expected 1 candidate but found ", len(candidates)) + suite.Equal("repo1", candidates[0].Repository) + } + + // simulate repo manager returning an error + { + suite.repoMgr.On("List", mock.Anything, mock.Anything).Return(nil, errors.New("test error")).Once() + candidates, err := suite.filterProcessor.ProcessRepositoryFilter(context.TODO(), "repo1", []int64{100}) + suite.Error(err) + suite.Nil(candidates) + } + + // simulate doublestar filtering + { + repoRecord3 := model.RepoRecord{ + RepositoryID: int64(1), + Name: "test/repo1/ubuntu", + ProjectID: int64(100), + Description: "test repo 1", + PullCount: 1, + StarCount: 4, + CreationTime: time.Time{}, + UpdateTime: time.Time{}, + } + repoRecord4 := model.RepoRecord{ + RepositoryID: int64(1), + Name: "test/repo1/centos", + ProjectID: int64(100), + Description: "test repo 2", + PullCount: 1, + StarCount: 4, 
CreationTime: time.Time{}, + UpdateTime: time.Time{}, + } + allRepos = append(allRepos, &repoRecord3, &repoRecord4) + suite.repoMgr.On("List", mock.Anything, mock.Anything).Return(allRepos, nil).Once() + candidates, err := suite.filterProcessor.ProcessRepositoryFilter(context.TODO(), "repo1/**", []int64{100}) + suite.NoError(err) + suite.Equal(2, len(candidates), "Expected 2 candidates but found ", len(candidates)) + m := map[string]bool{} + for _, cand := range candidates { + m[cand.Repository] = true + } + _, ok := m["repo1/ubuntu"] + suite.True(ok) + _, ok = m["repo1/centos"] + suite.True(ok) + } +} + +func (suite *FilterProcessorTestSuite) TestProcessTagFilter() { + + testTag1 := tag2.Tag{ + ID: int64(1), + RepositoryID: int64(1), + ArtifactID: int64(1), + Name: "test-tag1", + PushTime: time.Time{}, + PullTime: time.Time{}, + } + + testTag2 := tag2.Tag{ + ID: int64(2), + RepositoryID: int64(1), + ArtifactID: int64(1), + Name: "test-tag2", + PushTime: time.Time{}, + PullTime: time.Time{}, + } + + testTag3 := tag2.Tag{ + ID: int64(3), + RepositoryID: int64(2), + ArtifactID: int64(2), + Name: "test-tag3", + PushTime: time.Time{}, + PullTime: time.Time{}, + } + + allTags := make([]*tag2.Tag, 0) + + allTags = append(allTags, &testTag1, &testTag2) + + // filter required repositories having the specified tags + { + suite.tagMgr.On("List", mock.Anything, mock.Anything).Return([]*tag2.Tag{&testTag1, &testTag2}, nil).Once() + suite.tagMgr.On("List", mock.Anything, mock.Anything).Return([]*tag2.Tag{&testTag3}, nil).Once() + + candidates, err := suite.filterProcessor.ProcessTagFilter(context.TODO(), "*tag2", []int64{1, 2}) + suite.NoError(err) + suite.Equal(1, len(candidates), "Expected 1 candidate but found ", len(candidates)) + suite.Equal(int64(1), candidates[0].NamespaceID) + } + + // simulate tag manager returning an error + { + suite.tagMgr.On("List", mock.Anything, mock.Anything).Return(nil, errors.New("test error")).Once() + candidates, err := suite.filterProcessor.ProcessTagFilter(context.TODO(), "repo1", []int64{1, 2}) + suite.Error(err) + suite.Nil(candidates) + } + +} + +func TestFilterProcessorTestSuite(t *testing.T) { + suite.Run(t, &FilterProcessorTestSuite{}) +} diff --git a/src/pkg/scan/export/manager.go b/src/pkg/scan/export/manager.go new file mode 100644 index 000000000..7dbf39ddc --- /dev/null +++ b/src/pkg/scan/export/manager.go @@ -0,0 +1,233 @@ +package export + +import ( + "context" + "encoding/json" + "errors" + "fmt" + beego_orm "github.com/beego/beego/orm" + "github.com/goharbor/harbor/src/jobservice/logger" + "github.com/goharbor/harbor/src/lib/orm" + q2 "github.com/goharbor/harbor/src/lib/q" + "strconv" + "strings" +) + +const ( + VulnScanReportView = "vuln_scan_report" + VulnScanReportQuery = `select row_number() over() as result_row_id, project.project_id as project_id, project."name" as project_name, harbor_user.user_id as user_id, harbor_user.username as project_owner, repository.repository_id, repository.name as repository_name, +scanner_registration.id as scanner_id, scanner_registration."name" as scanner_name, +vulnerability_record.cve_id, vulnerability_record.package, vulnerability_record.severity, +vulnerability_record.cvss_score_v3, vulnerability_record.cvss_score_v2, vulnerability_record.cvss_vector_v3, vulnerability_record.cvss_vector_v2, vulnerability_record.cwe_ids from report_vulnerability_record inner join scan_report on report_vulnerability_record.report_uuid = scan_report.uuid +inner join artifact on scan_report.digest = artifact.digest +inner join 
artifact_reference on artifact.id = artifact_reference.child_id +inner join vulnerability_record on report_vulnerability_record.vuln_record_id = vulnerability_record.id +inner join project on artifact.project_id = project.project_id +inner join repository on artifact.repository_id = repository.repository_id +inner join tag on tag.repository_id = repository.repository_id +inner join harbor_user on project.owner_id = harbor_user.user_id +inner join scanner_registration on scan_report.registration_uuid = scanner_registration.uuid ` + ArtifactBylabelQueryTemplate = "select distinct artifact.id from artifact inner join label_reference on artifact.id = label_reference.artifact_id inner join harbor_label on label_reference.label_id = harbor_label.id and harbor_label.id in (%s)" + SQLAnd = " and " + RepositoryIDColumn = "repository.repository_id" + ProjectIDColumn = "project.project_id" + TagIDColumn = "tag.id" + ArtifactParentIDColumn = "artifact_reference.parent_id" + GroupBy = " group by " + GroupByCols = `package, vulnerability_record.severity, vulnerability_record.cve_id, project.project_id, harbor_user.user_id , +repository.repository_id, scanner_registration.id, vulnerability_record.cvss_score_v3, +vulnerability_record.cvss_score_v2, vulnerability_record.cvss_vector_v3, vulnerability_record.cvss_vector_v2, +vulnerability_record.cwe_ids` + JobModeExport = "export" + JobModeKey = "mode" +) + +var ( + Mgr = NewManager() +) + +// Params specifies the filters for controlling the scan data export process +type Params struct { + // cve ids + CVEIds string + + // A list of one or more labels for which to export the scan data, defaults to all if empty + Labels []int64 + + // A list of one or more projects for which to export the scan data, defaults to all if empty + Projects []int64 + + // A list of repositories for which to export the scan data, defaults to all if empty + Repositories []int64 + + // A list of tags for which to export the scan data, defaults to all if empty + Tags []int64 + + // PageNumber + PageNumber int64 + + // PageSize + PageSize int64 +} + +// FromJSON parses Params from JSON data +func (p *Params) FromJSON(jsonData string) error { + if len(jsonData) == 0 { + return errors.New("empty json data to parse") + } + + return json.Unmarshal([]byte(jsonData), p) +} + +// ToJSON marshals Params to JSON data +func (p *Params) ToJSON() (string, error) { + data, err := json.Marshal(p) + if err != nil { + return "", err + } + + return string(data), nil +} + +type Manager interface { + Fetch(ctx context.Context, params Params) ([]Data, error) +} + +type exportManager struct { + exportDataFilter VulnerabilityDataSelector +} + +func NewManager() Manager { + return &exportManager{exportDataFilter: NewVulnerabilityDataSelector()} +} + +func (em *exportManager) Fetch(ctx context.Context, params Params) ([]Data, error) { + exportData := make([]Data, 0) + artifactIdsWithLabel, err := em.getArtifactsWithLabel(ctx, params.Labels) + if err != nil { + return nil, err + } + // if labels are present but no artifact ids were retrieved then return empty + // results + if len(params.Labels) > 0 && len(artifactIdsWithLabel) == 0 { + return exportData, nil + } + + rawSeter, err := em.buildQuery(ctx, params, artifactIdsWithLabel) + if err != nil { + return nil, err + } + _, err = rawSeter.QueryRows(&exportData) + if err != nil { + return nil, err + } + exportData, err = em.exportDataFilter.Select(exportData, CVEIDMatches, params.CVEIds) + if err != nil { + return nil, err + } + return exportData, nil +} + +func (em *exportManager) buildQuery(ctx 
context.Context, params Params, artifactsWithLabel []int64) (beego_orm.RawSeter, error) { + sql := VulnScanReportQuery + filterFragment, err := em.getFilters(ctx, params, artifactsWithLabel) + if err != nil { + return nil, err + } + if len(filterFragment) > 0 { + sql = fmt.Sprintf("%s %s %s %s %s", VulnScanReportQuery, SQLAnd, filterFragment, GroupBy, GroupByCols) + } + logger.Infof("SQL query : %s", sql) + ormer, err := orm.FromContext(ctx) + + if err != nil { + return nil, err + } + logger.Infof("Parameters : %v", params) + pageSize := params.PageSize + q := &q2.Query{ + Keywords: nil, + Sorts: nil, + PageNumber: params.PageNumber, + PageSize: pageSize, + Sorting: "", + } + logger.Infof("Query constructed : %v", q) + paginationParams := make([]interface{}, 0) + query, pageLimits := orm.PaginationOnRawSQL(q, sql, paginationParams) + logger.Infof("Final Paginated query : %s", query) + logger.Infof("Final pagination parameters %v", pageLimits) + return ormer.Raw(query, pageLimits), nil +} + +func (em *exportManager) getFilters(ctx context.Context, params Params, artifactsWithLabel []int64) (string, error) { + // it is required that the request payload contains only IDs of the + // projects, repositories, tags and label objects. + // only CVE ID fields can be strings + filters := make([]string, 0) + if params.Repositories != nil { + filters = em.buildIDFilterFragmentWithIn(params.Repositories, filters, RepositoryIDColumn) + } + if params.Projects != nil { + filters = em.buildIDFilterFragmentWithIn(params.Projects, filters, ProjectIDColumn) + } + if params.Tags != nil { + filters = em.buildIDFilterFragmentWithIn(params.Tags, filters, TagIDColumn) + } + + if len(artifactsWithLabel) > 0 { + filters = em.buildIDFilterFragmentWithIn(artifactsWithLabel, filters, ArtifactParentIDColumn) + } + + if len(filters) == 0 { + return "", nil + } + logger.Infof("All filters : %v", filters) + completeFilter := strings.Builder{} + for _, filter := range filters { + if completeFilter.Len() > 0 { + completeFilter.WriteString(SQLAnd) + } + completeFilter.WriteString(filter) + } + return completeFilter.String(), nil +} + +func (em *exportManager) buildIDFilterFragmentWithIn(ids []int64, filters []string, column string) []string { + if len(ids) == 0 { + return filters + } + strIds := make([]string, 0) + for _, id := range ids { + strIds = append(strIds, strconv.FormatInt(id, 10)) + } + filters = append(filters, fmt.Sprintf(" %s in (%s)", column, strings.Join(strIds, ","))) + return filters +} + +// utility method to get all child artifacts belonging to a parent containing +// the specified label ids. +// Within Harbor, labels are attached to the root artifact whereas scan results +// are associated with the child artifact. 
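Concretely, with hypothetical label IDs 7 and 9, the two-step label filter works as follows (a sketch; the SQL in the comment is what ArtifactBylabelQueryTemplate expands to):

// step 1: resolve the (root) artifacts that carry the labels:
//   select distinct artifact.id from artifact
//   inner join label_reference on artifact.id = label_reference.artifact_id
//   inner join harbor_label on label_reference.label_id = harbor_label.id and harbor_label.id in (7,9)
ids, err := em.getArtifactsWithLabel(ctx, []int64{7, 9})
if err != nil {
	return nil, err
}
// step 2: getFilters then adds "artifact_reference.parent_id in (<ids>)" to the
// export query, so the exported rows belong to the children of the labelled artifacts
filterFragment, err := em.getFilters(ctx, params, ids)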
+func (em *exportManager) getArtifactsWithLabel(ctx context.Context, ids []int64) ([]int64, error) { + artifactIds := make([]int64, 0) + + if len(ids) == 0 { + return artifactIds, nil + } + strIds := make([]string, 0) + for _, id := range ids { + strIds = append(strIds, strconv.FormatInt(id, 10)) + } + artifactQuery := fmt.Sprintf(ArtifactBylabelQueryTemplate, strings.Join(strIds, ",")) + logger.Infof("Triggering artifact query: %s", artifactQuery) + + ormer, err := orm.FromContext(ctx) + if err != nil { + return nil, err + } + numRows, err := ormer.Raw(artifactQuery).QueryRows(&artifactIds) + if err != nil { + return nil, err + } + logger.Infof("Found %d artifacts with the specified labels", numRows) + + return artifactIds, nil +} diff --git a/src/pkg/scan/export/manager_test.go b/src/pkg/scan/export/manager_test.go new file mode 100644 index 000000000..dd78f322e --- /dev/null +++ b/src/pkg/scan/export/manager_test.go @@ -0,0 +1,371 @@ +package export + +import ( + "fmt" + "github.com/gogo/protobuf/sortkeys" + "github.com/goharbor/harbor/src/jobservice/job" + "github.com/goharbor/harbor/src/lib/q" + artifactDao "github.com/goharbor/harbor/src/pkg/artifact/dao" + labelDao "github.com/goharbor/harbor/src/pkg/label/dao" + labelModel "github.com/goharbor/harbor/src/pkg/label/model" + projectDao "github.com/goharbor/harbor/src/pkg/project/dao" + repoDao "github.com/goharbor/harbor/src/pkg/repository/dao" + "github.com/goharbor/harbor/src/pkg/repository/model" + daoscan "github.com/goharbor/harbor/src/pkg/scan/dao/scan" + "github.com/goharbor/harbor/src/pkg/scan/dao/scanner" + v1 "github.com/goharbor/harbor/src/pkg/scan/rest/v1" + tagDao "github.com/goharbor/harbor/src/pkg/tag/dao" + "github.com/goharbor/harbor/src/pkg/tag/model/tag" + userDao "github.com/goharbor/harbor/src/pkg/user/dao" + htesting "github.com/goharbor/harbor/src/testing" + "github.com/stretchr/testify/suite" + "testing" + "time" +) + +const RegistrationUUID = "scannerIdExportData" +const ReportUUD = "reportUUId" + +type ExportManagerSuite struct { + htesting.Suite + artifactDao artifactDao.DAO + projectDao projectDao.DAO + userDao userDao.DAO + repositoryDao repoDao.DAO + tagDao tagDao.DAO + scanDao daoscan.DAO + vulnDao daoscan.VulnerabilityRecordDao + labelDao labelDao.DAO + exportManager Manager + testDataId testDataIds +} + +type testDataIds struct { + artifactId []int64 + repositoryId []int64 + artRefId []int64 + reportId []int64 + tagId []int64 + vulnRecs []int64 + labelId []int64 + labelRefId []int64 +} + +func (suite *ExportManagerSuite) SetupSuite() { + suite.Suite.SetupSuite() + suite.artifactDao = artifactDao.New() + suite.projectDao = projectDao.New() + suite.userDao = userDao.New() + suite.repositoryDao = repoDao.New() + suite.tagDao = tagDao.New() + suite.scanDao = daoscan.New() + suite.vulnDao = daoscan.NewVulnerabilityRecordDao() + suite.labelDao = labelDao.New() + suite.exportManager = NewManager() + suite.setupTestData() +} + +func (suite *ExportManagerSuite) TearDownSuite() { + suite.clearTestData() +} + +func (suite *ExportManagerSuite) SetupTest() { + +} + +func (suite *ExportManagerSuite) TearDownTest() { + +} + +func (suite *ExportManagerSuite) clearTestData() { + + // delete label references, then the labels themselves + for _, labelRefId := range suite.testDataId.labelRefId { + err := suite.labelDao.DeleteReference(suite.Context(), labelRefId) + suite.NoError(err) + } + + for _, labelId := range suite.testDataId.labelId { + err := suite.labelDao.Delete(suite.Context(), labelId) + suite.NoError(err) + } + + for _, artRefId 
:= range suite.testDataId.artRefId { + err := suite.artifactDao.DeleteReference(suite.Context(), artRefId) + suite.NoError(err) + } + + for _, repoId := range suite.testDataId.repositoryId { + err := suite.repositoryDao.Delete(suite.Context(), repoId) + suite.NoError(err) + } + + for _, tagId := range suite.testDataId.tagId { + err := suite.tagDao.Delete(suite.Context(), tagId) + suite.NoError(err) + } + + for _, artId := range suite.testDataId.artifactId { + err := suite.artifactDao.Delete(suite.Context(), artId) + suite.NoError(err) + } + + err := scanner.DeleteRegistration(suite.Context(), RegistrationUUID) + suite.NoError(err, "Error when cleaning up scanner registrations") + + suite.cleanUpAdditionalData(ReportUUD, RegistrationUUID) +} + +func (suite *ExportManagerSuite) TestExport() { + // no label based filtering + { + data, err := suite.exportManager.Fetch(suite.Context(), Params{}) + suite.NoError(err) + suite.Equal(10, len(data)) + } + + // with label based filtering. all 10 records should be got back since there is a + // label associated with the artifact + { + p := Params{ + Labels: suite.testDataId.labelId, + } + data, err := suite.exportManager.Fetch(suite.Context(), p) + suite.NoError(err) + suite.Equal(10, len(data)) + } + + // with label based filtering and specifying a non-existent label Id. + // should return 0 records + { + allLabels, err := suite.labelDao.List(suite.Context(), &q.Query{}) + suite.NoError(err) + allLabelIds := make([]int64, 0) + for _, lbl := range allLabels { + allLabelIds = append(allLabelIds, lbl.ID) + } + sortkeys.Int64s(allLabelIds) + + p := Params{ + Labels: []int64{allLabelIds[len(allLabelIds)-1] + int64(1)}, + } + data, err := suite.exportManager.Fetch(suite.Context(), p) + suite.NoError(err) + suite.Equal(0, len(data)) + } + + // specify a project id that does not exist in the system. 
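The remaining cases below pin down the filter semantics of Fetch: an ID that matches nothing is not an error, it simply produces an empty result set. In caller terms (a sketch, assuming a manager wired as in this suite):

data, err := mgr.Fetch(ctx, export.Params{Projects: []int64{-1}})
// err == nil and len(data) == 0: unknown IDs yield no rows rather than failing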
+ { + p := Params{ + Projects: []int64{int64(-1)}, + } + data, err := suite.exportManager.Fetch(suite.Context(), p) + suite.NoError(err) + suite.Equal(0, len(data)) + } + + // specify a non-existent repository + { + p := Params{ + Repositories: []int64{int64(-1)}, + } + data, err := suite.exportManager.Fetch(suite.Context(), p) + suite.NoError(err) + suite.Equal(0, len(data)) + } + + // specify a non-existent tag + { + p := Params{ + Tags: []int64{int64(-1)}, + } + data, err := suite.exportManager.Fetch(suite.Context(), p) + suite.NoError(err) + suite.Equal(0, len(data)) + } +} + +func (suite *ExportManagerSuite) TestExportWithCVEFilter() { + { + p := Params{ + CVEIds: "CVE-ID2", + } + data, err := suite.exportManager.Fetch(suite.Context(), p) + suite.NoError(err) + suite.Equal(1, len(data)) + suite.Equal(p.CVEIds, data[0].CVEId) + } +} + +func (suite *ExportManagerSuite) registerScanner(registrationUUID string) { + r := &scanner.Registration{ + UUID: registrationUUID, + Name: registrationUUID, + Description: "sample registration", + URL: fmt.Sprintf("https://sample.scanner.com/%s", registrationUUID), + } + + _, err := scanner.AddRegistration(suite.Context(), r) + suite.NoError(err, "add new registration") +} + +func (suite *ExportManagerSuite) generateVulnerabilityRecordsForReport(registrationUUID string, numRecords int) []*daoscan.VulnerabilityRecord { + vulns := make([]*daoscan.VulnerabilityRecord, 0) + for i := 1; i <= numRecords; i++ { + vulnV2 := new(daoscan.VulnerabilityRecord) + vulnV2.CVEID = fmt.Sprintf("CVE-ID%d", i) + vulnV2.Package = fmt.Sprintf("Package%d", i) + vulnV2.PackageVersion = "NotAvailable" + vulnV2.PackageType = "Unknown" + vulnV2.Fix = "1.0.0" + vulnV2.URLs = "url1" + vulnV2.RegistrationUUID = registrationUUID + // check i%4 first: every multiple of 4 is also a multiple of 2, so the + // Critical branch would otherwise be unreachable + if i%4 == 0 { + vulnV2.Severity = "Critical" + } else if i%2 == 0 { + vulnV2.Severity = "High" + } else if i%3 == 0 { + vulnV2.Severity = "Medium" + } else { + vulnV2.Severity = "Low" + } + vulns = append(vulns, vulnV2) + } + + return vulns +} + +func (suite *ExportManagerSuite) insertVulnRecordForReport(reportUUID string, vr *daoscan.VulnerabilityRecord) { + id, err := suite.vulnDao.Create(suite.Context(), vr) + suite.NoError(err, "Failed to create vulnerability record") + suite.testDataId.vulnRecs = append(suite.testDataId.vulnRecs, id) + + err = suite.vulnDao.InsertForReport(suite.Context(), reportUUID, id) + suite.NoError(err, "Failed to insert vulnerability record row for report %s", reportUUID) +} + +func (suite *ExportManagerSuite) cleanUpAdditionalData(reportID string, scannerID string) { + _, err := suite.scanDao.DeleteMany(suite.Context(), q.Query{Keywords: q.KeyWords{"uuid": reportID}}) + + suite.NoError(err) + _, err = suite.vulnDao.DeleteForReport(suite.Context(), reportID) + suite.NoError(err, "Failed to cleanup records") + _, err = suite.vulnDao.DeleteForScanner(suite.Context(), scannerID) + suite.NoError(err, "Failed to delete vulnerability records") +} + +func (suite *ExportManagerSuite) setupTestData() { + // create repositories + repoRecord := &model.RepoRecord{ + Name: "library/ubuntu", + ProjectID: 1, + Description: "", + PullCount: 1, + StarCount: 0, + CreationTime: time.Time{}, + UpdateTime: time.Time{}, + } + repoId, err := suite.repositoryDao.Create(suite.Context(), repoRecord) + suite.NoError(err) + suite.testDataId.repositoryId = append(suite.testDataId.repositoryId, repoId) + + // create artifacts for repositories + art := &artifactDao.Artifact{ + Type: "IMAGE", + MediaType: "application/vnd.docker.container.image.v1+json", + 
ManifestMediaType: "application/vnd.docker.distribution.manifest.v2+json", + ProjectID: 1, + RepositoryID: repoId, + RepositoryName: "library/ubuntu", + Digest: "sha256:e3d7ff9efd8431d9ef39a144c45992df5502c995b9ba3c53ff70c5b52a848d9c", + Size: 28573056, + Icon: "", + PushTime: time.Time{}, + PullTime: time.Time{}.Add(-10 * time.Minute), + ExtraAttrs: `{"architecture":"amd64","author":"","config":{"Env":["PATH=/usr/local/sbin:/usr/local/bin:/usr/sbin:/usr/bin:/sbin:/bin"],"Cmd":["/bin/bash"]},"created":"2021-03-04T02:24:42.927713926Z","os":"linux"}`, + Annotations: "", + } + artId, err := suite.artifactDao.Create(suite.Context(), art) + suite.NoError(err) + suite.testDataId.artifactId = append(suite.testDataId.artifactId, artId) + + // create a tag and associate with the repository + t := &tag.Tag{ + RepositoryID: repoId, + ArtifactID: artId, + Name: "latest", + PushTime: time.Time{}, + PullTime: time.Time{}, + } + tagId, err := suite.tagDao.Create(suite.Context(), t) + suite.NoError(err) + suite.testDataId.tagId = append(suite.testDataId.tagId, tagId) + + // create an artifact reference + artReference := &artifactDao.ArtifactReference{ + ParentID: artId, + ChildID: artId, + ChildDigest: "sha256:e3d7ff9efd8431d9ef39a144c45992df5502c995b9ba3c53ff70c5b52a848d9c", + Platform: `{"architecture":"amd64","os":"linux"}`, + URLs: "", + Annotations: "", + } + artRefId, err := suite.artifactDao.CreateReference(suite.Context(), artReference) + suite.NoError(err) + suite.testDataId.artRefId = append(suite.testDataId.artRefId, artRefId) + + // create a label + l := labelModel.Label{ + Name: "TestLabel", + Description: "", + Color: "Green", + Level: "", + Scope: "", + ProjectID: 1, + CreationTime: time.Time{}, + UpdateTime: time.Time{}, + Deleted: false, + } + labelId, err := suite.labelDao.Create(suite.Context(), &l) + suite.NoError(err) + suite.testDataId.labelId = append(suite.testDataId.labelId, labelId) + + lRef := labelModel.Reference{ + ID: 0, + LabelID: labelId, + ArtifactID: artId, + CreationTime: time.Time{}, + UpdateTime: time.Time{}, + } + lRefId, err := suite.labelDao.CreateReference(suite.Context(), &lRef) + suite.NoError(err) + suite.testDataId.labelRefId = append(suite.testDataId.labelRefId, lRefId) + + // register a scanner + suite.registerScanner(RegistrationUUID) + + // create a vulnerability scan report + r := &daoscan.Report{ + UUID: ReportUUD, + Digest: "sha256:e3d7ff9efd8431d9ef39a144c45992df5502c995b9ba3c53ff70c5b52a848d9c", + RegistrationUUID: RegistrationUUID, + MimeType: v1.MimeTypeGenericVulnerabilityReport, + Status: job.PendingStatus.String(), + Report: "", + } + reportId, err := suite.scanDao.Create(suite.Context(), r) + suite.NoError(err) + suite.testDataId.reportId = append(suite.testDataId.reportId, reportId) + + // generate vulnerability records for the report + vulns := suite.generateVulnerabilityRecordsForReport(RegistrationUUID, 10) + suite.NotEmpty(vulns) + + for _, vuln := range vulns { + suite.insertVulnRecordForReport(ReportUUD, vuln) + } +} + +func TestExportManager(t *testing.T) { + suite.Run(t, &ExportManagerSuite{}) +} diff --git a/src/pkg/scan/export/model.go b/src/pkg/scan/export/model.go new file mode 100644 index 000000000..5c6c18645 --- /dev/null +++ b/src/pkg/scan/export/model.go @@ -0,0 +1,115 @@ +package export + +import ( + "encoding/json" + "errors" + "time" +) + +// Data models a single row of the exported scan vulnerability data + +type Data struct { + ID int64 `orm:"column(result_row_id)" csv:"RowId"` + ProjectName string 
`orm:"column(project_name)" csv:"Project"` + ProjectOwner string `orm:"column(project_owner)" csv:"Owner"` + ScannerName string `orm:"column(scanner_name)" csv:"Scanner"` + Repository string `orm:"column(repository_name)" csv:"Repository"` + ArtifactDigest string `orm:"column(artifact_digest)" csv:"Artifact Digest"` + CVEId string `orm:"column(cve_id)" csv:"CVE"` + Package string `orm:"column(package)" csv:"Package"` + Severity string `orm:"column(severity)" csv:"Severity"` + CVSSScoreV3 string `orm:"column(cvss_score_v3)" csv:"CVSS V3 Score"` + CVSSScoreV2 string `orm:"column(cvss_score_v2)" csv:"CVSS V2 Score"` + CVSSVectorV3 string `orm:"column(cvss_vector_v3)" csv:"CVSS V3 Vector"` + CVSSVectorV2 string `orm:"column(cvss_vector_v2)" csv:"CVSS V2 Vector"` + CWEIds string `orm:"column(cwe_ids)" csv:"CWE Ids"` +} + +// Request encapsulates the filters to be provided when exporting the data for a scan. +type Request struct { + + // UserID contains the database identity of the user initiating the export request + UserID int + + // UserName contains the name of the user initiating the export request + UserName string + + // JobName contains the name of the job as specified by the external client. + JobName string + + // cve ids + CVEIds string + + // A list of one or more labels for which to export the scan data, defaults to all if empty + Labels []int64 + + // A list of one or more projects for which to export the scan data, defaults to all if empty + Projects []int64 + + // A list of repositories for which to export the scan data, defaults to all if empty + Repositories string + + // A list of tags for which to export the scan data, defaults to all if empty + Tags string +} + +// FromJSON parses robot from json data +func (c *Request) FromJSON(jsonData string) error { + if len(jsonData) == 0 { + return errors.New("empty json data to parse") + } + + return json.Unmarshal([]byte(jsonData), c) +} + +// ToJSON marshals Robot to JSON data +func (c *Request) ToJSON() (string, error) { + data, err := json.Marshal(c) + if err != nil { + return "", err + } + + return string(data), nil +} + +// Execution provides details about the running status of a scan data export job +type Execution struct { + // ID of the execution + ID int64 + // UserID triggering the execution + UserID int64 + // Status provides the status of the execution + Status string + // StatusMessage contains the human-readable status message for the execution + StatusMessage string + // Trigger indicates the mode of trigger for the job execution + Trigger string + // StartTime contains the start time instant of the execution + StartTime time.Time + // EndTime contains the end time instant of the execution + EndTime time.Time + // ExportDataDigest contains the SHA256 hash of the exported scan data artifact + ExportDataDigest string + // Name of the job as specified during the export task invocation + JobName string + // Name of the user triggering the job + UserName string + // FilePresent is true if file artifact is actually present, false otherwise + FilePresent bool +} + +type Task struct { + // ID of the scan data export task + ID int64 + // Job Id corresponding to the task + JobID string + // Status of the current task execution + Status string + // Status message for the current task execution + StatusMessage string +} + +type TriggerParam struct { + TimeWindowMinutes int + PageSize int +} diff --git a/src/pkg/systemartifact/manager.go b/src/pkg/systemartifact/manager.go index 47b287c68..e2cb54a62 100644 --- 
a/src/pkg/systemartifact/manager.go
+++ b/src/pkg/systemartifact/manager.go
@@ -11,6 +11,7 @@ import (
 	"github.com/goharbor/harbor/src/pkg/systemartifact/model"
 	"io"
 	"sync"
+	"time"
 )
 
 var (
@@ -18,8 +19,8 @@ var (
 	keyFormat = "%s:%s"
 )
 
-const repositoryFormat = "sys_harbor/%s/%s"
-const systemArtifactProjectName = "sys_h@rb0r"
+const repositoryFormat = "sys_harb0r/%s/%s"
+const systemArtifactProjectName = "sys_harb0r"
 
 // Manager provides a low-level interface for harbor services
 // to create registry artifacts containing arbitrary data but which
@@ -91,7 +92,10 @@ func NewManager() Manager {
 func (mgr *systemArtifactManager) Create(ctx context.Context, artifactRecord *model.SystemArtifact, reader io.Reader) (int64, error) {
 	var artifactID int64
-
+	// create time defaults to current time if unset
+	if artifactRecord.CreateTime.IsZero() {
+		artifactRecord.CreateTime = time.Now()
+	}
 	// the entire create operation is executed within a transaction to ensure that any failures
 	// during the blob creation or tracking record creation result in a rollback of the transaction
 	createError := orm.WithTransaction(func(ctx context.Context) error {
@@ -103,6 +107,7 @@ func (mgr *systemArtifactManager) Create(ctx context.Context, artifactRecord *mo
 		repoName := mgr.getRepositoryName(artifactRecord.Vendor, artifactRecord.Repository)
 		err = mgr.regCli.PushBlob(repoName, artifactRecord.Digest, artifactRecord.Size, reader)
 		if err != nil {
+			log.Errorf("Error pushing blob for system artifact %s/%s/%s: %v", artifactRecord.Vendor, artifactRecord.Repository, artifactRecord.Digest, err)
 			return err
 		}
 		artifactID = id
diff --git a/src/pkg/systemartifact/manager_test.go b/src/pkg/systemartifact/manager_test.go
index d3af08c99..f62796f95 100644
--- a/src/pkg/systemartifact/manager_test.go
+++ b/src/pkg/systemartifact/manager_test.go
@@ -63,6 +63,24 @@ func (suite *ManagerTestSuite) TestCreate() {
 	suite.regCli.AssertCalled(suite.T(), "PushBlob")
 }
 
+func (suite *ManagerTestSuite) TestCreateTimeNotSet() {
+	sa := model.SystemArtifact{
+		Repository: "test_repo",
+		Digest:     "test_digest",
+		Size:       int64(100),
+		Vendor:     "test_vendor",
+		Type:       "test_type",
+	}
+	suite.dao.On("Create", mock.Anything, &sa, mock.Anything).Return(int64(1), nil).Once()
+	suite.regCli.On("PushBlob", mock.Anything, mock.Anything, mock.Anything, mock.Anything).Return(nil).Once()
+	reader := strings.NewReader("test data string")
+	id, err := suite.mgr.Create(orm.NewContext(nil, &ormtesting.FakeOrmer{}), &sa, reader)
+	suite.Equalf(int64(1), id, "Expected row to be correctly inserted")
+	suite.NoErrorf(err, "Unexpected error when creating artifact: %v", err)
+	suite.regCli.AssertCalled(suite.T(), "PushBlob")
+	suite.False(sa.CreateTime.IsZero(), "Create time expected to be set")
+}
+
 func (suite *ManagerTestSuite) TestCreatePushBlobFails() {
 	sa := model.SystemArtifact{
 		Repository: "test_repo",
@@ -171,7 +189,7 @@ func (suite *ManagerTestSuite) TestDelete() {
 	suite.NoErrorf(err, "Unexpected error when deleting artifact: %v", err)
 	suite.dao.AssertCalled(suite.T(), "Delete", mock.Anything, "test_vendor", "test_repo", "test_digest")
-	suite.regCli.AssertCalled(suite.T(), "DeleteBlob")
+	suite.regCli.AssertCalled(suite.T(), "DeleteBlob", "sys_harb0r/test_vendor/test_repo", "test_digest")
 }
 
 func (suite *ManagerTestSuite) TestDeleteSystemArtifactDeleteError() {
@@ -184,7 +202,7 @@ func (suite *ManagerTestSuite) TestDeleteSystemArtifactDeleteError() {
 	suite.Errorf(err, "Expected error when deleting artifact: %v", err)
 	suite.dao.AssertCalled(suite.T(), "Delete",
mock.Anything, "test_vendor", "test_repo", "test_digest") - suite.regCli.AssertCalled(suite.T(), "DeleteBlob") + suite.regCli.AssertCalled(suite.T(), "DeleteBlob", "sys_harb0r/test_vendor/test_repo", "test_digest") } func (suite *ManagerTestSuite) TestDeleteSystemArtifactBlobDeleteError() { @@ -197,7 +215,7 @@ func (suite *ManagerTestSuite) TestDeleteSystemArtifactBlobDeleteError() { suite.Errorf(err, "Expected error when deleting artifact: %v", err) suite.dao.AssertNotCalled(suite.T(), "Delete", mock.Anything, "test_vendor", "test_repo", "test_digest") - suite.regCli.AssertCalled(suite.T(), "DeleteBlob") + suite.regCli.AssertCalled(suite.T(), "DeleteBlob", "sys_harb0r/test_vendor/test_repo", "test_digest") } func (suite *ManagerTestSuite) TestExist() { diff --git a/src/server/v2.0/handler/handler.go b/src/server/v2.0/handler/handler.go index e4a141741..ea0380ea6 100644 --- a/src/server/v2.0/handler/handler.go +++ b/src/server/v2.0/handler/handler.go @@ -65,6 +65,7 @@ func New() http.Handler { StatisticAPI: newStatisticAPI(), ProjectMetadataAPI: newProjectMetadaAPI(), PurgeAPI: newPurgeAPI(), + ScanDataExportAPI: newScanDataExportAPI(), }) if err != nil { log.Fatal(err) diff --git a/src/server/v2.0/handler/scanexport.go b/src/server/v2.0/handler/scanexport.go new file mode 100644 index 000000000..8d70f3219 --- /dev/null +++ b/src/server/v2.0/handler/scanexport.go @@ -0,0 +1,237 @@ +package handler + +import ( + "context" + "fmt" + "github.com/goharbor/harbor/src/pkg/user" + "io" + "net/http" + "strings" + + "github.com/go-openapi/runtime" + "github.com/go-openapi/runtime/middleware" + "github.com/go-openapi/strfmt" + "github.com/goharbor/harbor/src/common/rbac" + "github.com/goharbor/harbor/src/controller/scandataexport" + "github.com/goharbor/harbor/src/jobservice/logger" + "github.com/goharbor/harbor/src/lib/log" + "github.com/goharbor/harbor/src/lib/orm" + "github.com/goharbor/harbor/src/pkg/scan/export" + v1 "github.com/goharbor/harbor/src/pkg/scan/rest/v1" + "github.com/goharbor/harbor/src/pkg/systemartifact" + "github.com/goharbor/harbor/src/server/v2.0/models" + operation "github.com/goharbor/harbor/src/server/v2.0/restapi/operations/scan_data_export" +) + +func newScanDataExportAPI() *scanDataExportAPI { + return &scanDataExportAPI{ + scanDataExportCtl: scandataexport.Ctl, + sysArtifactMgr: systemartifact.Mgr, + userMgr: user.Mgr, + } +} + +type scanDataExportAPI struct { + BaseAPI + scanDataExportCtl scandataexport.Controller + sysArtifactMgr systemartifact.Manager + userMgr user.Manager +} + +func (se *scanDataExportAPI) Prepare(ctx context.Context, operation string, params interface{}) middleware.Responder { + return nil +} + +func (se *scanDataExportAPI) ExportScanData(ctx context.Context, params operation.ExportScanDataParams) middleware.Responder { + if err := se.RequireAuthenticated(ctx); err != nil { + return se.SendError(ctx, err) + } + + // check if the MIME type for the export is the Generic vulnerability data + if params.XScanDataType != v1.MimeTypeGenericVulnerabilityReport { + error := &models.Error{Message: fmt.Sprintf("Unsupported MIME type : %s", params.XScanDataType)} + errors := &models.Errors{Errors: []*models.Error{error}} + return operation.NewExportScanDataBadRequest().WithPayload(errors) + } + + // loop through the list of projects and validate that scan privilege and create privilege + // is available for all projects + // TODO : Should we just ignore projects that do not have the required level of access? 
+
+	projects := params.Criteria.Projects
+	for _, project := range projects {
+		if err := se.RequireProjectAccess(ctx, project, rbac.ActionCreate, rbac.ResourceScan); err != nil {
+			return se.SendError(ctx, err)
+		}
+	}
+
+	scanDataExportJob := new(models.ScanDataExportJob)
+
+	secContext, err := se.GetSecurityContext(ctx)
+	if err != nil {
+		return se.SendError(ctx, err)
+	}
+
+	// vendor id associated with the job == the user id
+	usr, err := se.userMgr.GetByName(ctx, secContext.GetUsername())
+	if err != nil {
+		return se.SendError(ctx, err)
+	}
+
+	if usr == nil {
+		error := &models.Error{Message: fmt.Sprintf("User : %s not found", secContext.GetUsername())}
+		errors := &models.Errors{Errors: []*models.Error{error}}
+		return operation.NewExportScanDataForbidden().WithPayload(errors)
+	}
+
+	userContext := context.WithValue(ctx, export.CsvJobVendorIDKey, usr.UserID)
+
+	jobID, err := se.scanDataExportCtl.Start(userContext, se.convertToCriteria(params.Criteria, secContext.GetUsername(), usr.UserID))
+	if err != nil {
+		return se.SendError(ctx, err)
+	}
+	scanDataExportJob.ID = jobID
+	return operation.NewExportScanDataOK().WithPayload(scanDataExportJob)
+}
+
+func (se *scanDataExportAPI) GetScanDataExportExecution(ctx context.Context, params operation.GetScanDataExportExecutionParams) middleware.Responder {
+	err := se.RequireAuthenticated(ctx)
+	if err != nil {
+		return se.SendError(ctx, err)
+	}
+	execution, err := se.scanDataExportCtl.GetExecution(ctx, params.ExecutionID)
+	if err != nil {
+		return se.SendError(ctx, err)
+	}
+	sdeExec := models.ScanDataExportExecution{
+		EndTime:     strfmt.DateTime(execution.EndTime),
+		ID:          execution.ID,
+		StartTime:   strfmt.DateTime(execution.StartTime),
+		Status:      execution.Status,
+		StatusText:  execution.StatusMessage,
+		Trigger:     execution.Trigger,
+		UserID:      execution.UserID,
+		JobName:     execution.JobName,
+		UserName:    execution.UserName,
+		FilePresent: execution.FilePresent,
+	}
+
+	return operation.NewGetScanDataExportExecutionOK().WithPayload(&sdeExec)
+}
+
+func (se *scanDataExportAPI) DownloadScanData(ctx context.Context, params operation.DownloadScanDataParams) middleware.Responder {
+	err := se.RequireAuthenticated(ctx)
+	if err != nil {
+		return se.SendError(ctx, err)
+	}
+	execution, err := se.scanDataExportCtl.GetExecution(ctx, params.ExecutionID)
+	if err != nil {
+		return se.SendError(ctx, err)
+	}
+
+	// check if the execution being downloaded is owned by the current user
+	secContext, err := se.GetSecurityContext(ctx)
+	if err != nil {
+		return se.SendError(ctx, err)
+	}
+
+	if secContext.GetUsername() != execution.UserName {
+		return middleware.ResponderFunc(func(writer http.ResponseWriter, producer runtime.Producer) {
+			writer.WriteHeader(http.StatusUnauthorized)
+		})
+	}
+
+	repositoryName := fmt.Sprintf("scandata_export_%v", params.ExecutionID)
+	file, err := se.sysArtifactMgr.Read(ctx, strings.ToLower(export.Vendor), repositoryName, execution.ExportDataDigest)
+	if err != nil {
+		return se.SendError(ctx, err)
+	}
+	logger.Infof("reading data from file : %s", repositoryName)
+
+	return middleware.ResponderFunc(func(writer http.ResponseWriter, producer runtime.Producer) {
+		defer se.cleanUpArtifact(ctx, repositoryName, execution.ExportDataDigest, params.ExecutionID, file)
+
+		writer.Header().Set("Content-Type", "text/csv")
+		writer.Header().Set("Content-Disposition", fmt.Sprintf("attachment; filename=%q", fmt.Sprintf("%s.csv", repositoryName)))
+		nbytes, err := io.Copy(writer, file)
+
if err != nil {
+			logger.Errorf("Encountered error while copying data: %v", err)
+		} else {
+			logger.Debugf("Copied %v bytes from file to client", nbytes)
+		}
+	})
+}
+
+func (se *scanDataExportAPI) GetScanDataExportExecutionList(ctx context.Context, params operation.GetScanDataExportExecutionListParams) middleware.Responder {
+	err := se.RequireAuthenticated(ctx)
+	if err != nil {
+		return se.SendError(ctx, err)
+	}
+	executions, err := se.scanDataExportCtl.ListExecutions(ctx, params.UserName)
+	if err != nil {
+		return se.SendError(ctx, err)
+	}
+	execs := make([]*models.ScanDataExportExecution, 0)
+	for _, execution := range executions {
+		sdeExec := &models.ScanDataExportExecution{
+			EndTime:     strfmt.DateTime(execution.EndTime),
+			ID:          execution.ID,
+			StartTime:   strfmt.DateTime(execution.StartTime),
+			Status:      execution.Status,
+			StatusText:  execution.StatusMessage,
+			Trigger:     execution.Trigger,
+			UserID:      execution.UserID,
+			UserName:    execution.UserName,
+			JobName:     execution.JobName,
+			FilePresent: execution.FilePresent,
+		}
+		execs = append(execs, sdeExec)
+	}
+	sdeExecList := models.ScanDataExportExecutionList{Items: execs}
+	return operation.NewGetScanDataExportExecutionListOK().WithPayload(&sdeExecList)
+}
+
+func (se *scanDataExportAPI) convertToCriteria(requestCriteria *models.ScanDataExportRequest, userName string, userID int) export.Request {
+	return export.Request{
+		UserID:       userID,
+		UserName:     userName,
+		JobName:      requestCriteria.JobName,
+		CVEIds:       requestCriteria.CVEIds,
+		Labels:       requestCriteria.Labels,
+		Projects:     requestCriteria.Projects,
+		Repositories: requestCriteria.Repositories,
+		Tags:         requestCriteria.Tags,
+	}
+}
+
+func (se *scanDataExportAPI) cleanUpArtifact(ctx context.Context, repositoryName, digest string, execID int64, file io.ReadCloser) {
+	file.Close()
+	logger.Infof("Deleting report artifact : %v:%v", repositoryName, digest)
+
+	// the entire delete operation is executed within a transaction to ensure that any failures
+	// during the artifact deletion or the execution record cleanup result in a rollback of the transaction
+	vendor := strings.ToLower(export.Vendor)
+	err := orm.WithTransaction(func(ctx context.Context) error {
+		err := se.sysArtifactMgr.Delete(ctx, vendor, repositoryName, digest)
+		if err != nil {
+			log.Errorf("Error deleting system artifact record for %s/%s/%s: %v", vendor, repositoryName, digest, err)
+			return err
+		}
+		// delete the underlying execution
+		err = se.scanDataExportCtl.DeleteExecution(ctx, execID)
+		if err != nil {
+			log.Errorf("Error deleting csv export job execution for %s/%s/%s: %v", vendor, repositoryName, digest, err)
+		}
+		return err
+	})(ctx)
+
+	if err != nil {
+		log.Errorf("Error cleaning up system artifact for %s/%s/%s: %v", vendor, repositoryName, digest, err)
+	}
+}
diff --git a/src/server/v2.0/handler/scanexport_test.go b/src/server/v2.0/handler/scanexport_test.go
new file mode 100644
index 000000000..0fa5a3e90
--- /dev/null
+++ b/src/server/v2.0/handler/scanexport_test.go
@@ -0,0 +1,396 @@
+package handler
+
+import (
+	"bytes"
+	"encoding/json"
+	"errors"
+	"fmt"
+	commonmodels "github.com/goharbor/harbor/src/common/models"
+	"github.com/goharbor/harbor/src/pkg/scan/export"
+	v1 "github.com/goharbor/harbor/src/pkg/scan/rest/v1"
+	"github.com/goharbor/harbor/src/server/v2.0/models"
+	"github.com/goharbor/harbor/src/server/v2.0/restapi"
+	"github.com/goharbor/harbor/src/testing/controller/scandataexport"
+	"github.com/goharbor/harbor/src/testing/mock"
+	systemartifacttesting
"github.com/goharbor/harbor/src/testing/pkg/systemartifact" + "github.com/goharbor/harbor/src/testing/pkg/user" + htesting "github.com/goharbor/harbor/src/testing/server/v2.0/handler" + testifymock "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/suite" + "io" + "net/http" + url2 "net/url" + "strings" + "testing" + "time" +) + +type ScanExportTestSuite struct { + htesting.Suite + scanExportCtl *scandataexport.Controller + sysArtifactMgr *systemartifacttesting.Manager + userMgr *user.Manager +} + +func (suite *ScanExportTestSuite) SetupSuite() { + +} + +func (suite *ScanExportTestSuite) SetupTest() { + + suite.scanExportCtl = &scandataexport.Controller{} + suite.sysArtifactMgr = &systemartifacttesting.Manager{} + suite.userMgr = &user.Manager{} + suite.Config = &restapi.Config{ + ScanDataExportAPI: &scanDataExportAPI{ + scanDataExportCtl: suite.scanExportCtl, + sysArtifactMgr: suite.sysArtifactMgr, + userMgr: suite.userMgr, + }, + } + suite.Suite.SetupSuite() +} + +func (suite *ScanExportTestSuite) TestAuthorization() { + { + criteria := models.ScanDataExportRequest{ + CVEIds: "CVE-123", + Labels: []int64{100}, + Projects: []int64{200}, + Repositories: "test-repo", + Tags: "{test-tag1, test-tag2}", + } + + reqs := []struct { + method string + url string + body interface{} + headers map[string]string + }{ + {http.MethodPost, "/export/cve", criteria, map[string]string{"X-Scan-Data-Type": v1.MimeTypeGenericVulnerabilityReport}}, + {http.MethodGet, "/export/cve/execution/100", nil, nil}, + {http.MethodGet, "/export/cve/download/100", nil, nil}, + } + + suite.Security.On("IsAuthenticated").Return(false).Times(3) + for _, req := range reqs { + + if req.body != nil && req.method == http.MethodPost { + data, _ := json.Marshal(criteria) + buffer := bytes.NewBuffer(data) + res, _ := suite.DoReq(req.method, req.url, buffer, req.headers) + suite.Equal(401, res.StatusCode) + } else { + res, _ := suite.DoReq(req.method, req.url, nil) + suite.Equal(401, res.StatusCode) + } + + } + } +} +func (suite *ScanExportTestSuite) TestExportScanData() { + suite.Security.On("GetUsername").Return("test-user") + suite.Security.On("Can", mock.Anything, mock.Anything, mock.Anything).Return(true).Once() + usr := commonmodels.User{UserID: 1000, Username: "test-user"} + suite.userMgr.On("GetByName", mock.Anything, "test-user").Return(&usr, nil).Once() + // user authenticated and correct headers sent + { + suite.Security.On("IsAuthenticated").Return(true).Once() + url := "/export/cve" + criteria := models.ScanDataExportRequest{ + JobName: "test-job", + CVEIds: "CVE-123", + Labels: []int64{100}, + Projects: []int64{200}, + Repositories: "test-repo", + Tags: "{test-tag1, test-tag2}", + } + + data, err := json.Marshal(criteria) + buffer := bytes.NewBuffer(data) + + headers := make(map[string]string) + headers["X-Scan-Data-Type"] = v1.MimeTypeGenericVulnerabilityReport + + // data, err := json.Marshal(criteria) + mock.OnAnything(suite.scanExportCtl, "Start").Return(int64(100), nil).Once() + res, err := suite.DoReq(http.MethodPost, url, buffer, headers) + suite.Equal(200, res.StatusCode) + + suite.Equal(nil, err) + respData := make(map[string]interface{}) + json.NewDecoder(res.Body).Decode(&respData) + suite.Equal(int64(100), int64(respData["id"].(float64))) + + // validate job name and user name set in the request for job execution + jobRequestMatcher := testifymock.MatchedBy(func(req export.Request) bool { + return req.UserName == "test-user" && req.JobName == "test-job" && req.Tags == "{test-tag1, test-tag2}" 
&& req.UserID == 1000 + }) + suite.scanExportCtl.AssertCalled(suite.T(), "Start", mock.Anything, jobRequestMatcher) + } + + // user authenticated but incorrect/unsupported header sent across + { + suite.Security.On("IsAuthenticated").Return(true).Once() + url := "/export/cve" + + criteria := models.ScanDataExportRequest{ + CVEIds: "CVE-123", + Labels: []int64{100}, + Projects: []int64{200}, + Repositories: "test-repo", + Tags: "{test-tag1, test-tag2}", + } + + data, err := json.Marshal(criteria) + buffer := bytes.NewBuffer(data) + + headers := make(map[string]string) + headers["X-Scan-Data-Type"] = "test" + + mock.OnAnything(suite.scanExportCtl, "Start").Return(int64(100), nil).Once() + res, err := suite.DoReq(http.MethodPost, url, buffer, headers) + suite.Equal(400, res.StatusCode) + suite.Equal(nil, err) + } + +} + +func (suite *ScanExportTestSuite) TestExportScanDataGetUserIdError() { + suite.Security.On("GetUsername").Return("test-user") + suite.Security.On("Can", mock.Anything, mock.Anything, mock.Anything).Return(true).Once() + suite.userMgr.On("GetByName", mock.Anything, "test-user").Return(nil, errors.New("test error")).Once() + // user authenticated and correct headers sent + { + suite.Security.On("IsAuthenticated").Return(true).Once() + url := "/export/cve" + criteria := models.ScanDataExportRequest{ + JobName: "test-job", + CVEIds: "CVE-123", + Labels: []int64{100}, + Projects: []int64{200}, + Repositories: "test-repo", + Tags: "{test-tag1, test-tag2}", + } + + data, err := json.Marshal(criteria) + buffer := bytes.NewBuffer(data) + + headers := make(map[string]string) + headers["X-Scan-Data-Type"] = v1.MimeTypeGenericVulnerabilityReport + + // data, err := json.Marshal(criteria) + mock.OnAnything(suite.scanExportCtl, "Start").Return(int64(100), nil).Once() + res, err := suite.DoReq(http.MethodPost, url, buffer, headers) + suite.Equal(http.StatusInternalServerError, res.StatusCode) + suite.Equal(nil, err) + + suite.scanExportCtl.AssertNotCalled(suite.T(), "Start") + } +} + +func (suite *ScanExportTestSuite) TestExportScanDataGetUserIdNotFound() { + suite.Security.On("GetUsername").Return("test-user") + suite.Security.On("Can", mock.Anything, mock.Anything, mock.Anything).Return(true).Once() + suite.userMgr.On("GetByName", mock.Anything, "test-user").Return(nil, nil).Once() + // user authenticated and correct headers sent + { + suite.Security.On("IsAuthenticated").Return(true).Once() + url := "/export/cve" + criteria := models.ScanDataExportRequest{ + JobName: "test-job", + CVEIds: "CVE-123", + Labels: []int64{100}, + Projects: []int64{200}, + Repositories: "test-repo", + Tags: "{test-tag1, test-tag2}", + } + + data, err := json.Marshal(criteria) + buffer := bytes.NewBuffer(data) + + headers := make(map[string]string) + headers["X-Scan-Data-Type"] = v1.MimeTypeGenericVulnerabilityReport + + // data, err := json.Marshal(criteria) + mock.OnAnything(suite.scanExportCtl, "Start").Return(int64(100), nil).Once() + res, err := suite.DoReq(http.MethodPost, url, buffer, headers) + suite.Equal(http.StatusForbidden, res.StatusCode) + suite.Equal(nil, err) + + suite.scanExportCtl.AssertNotCalled(suite.T(), "Start") + } +} + +func (suite *ScanExportTestSuite) TestExportScanDataNoPrivileges() { + suite.Security.On("IsAuthenticated").Return(true).Times(2) + suite.Security.On("Can", mock.Anything, mock.Anything, mock.Anything).Return(false).Once() + url := "/export/cve" + + criteria := models.ScanDataExportRequest{ + JobName: "test-job", + CVEIds: "CVE-123", + Labels: []int64{100}, + Projects: 
[]int64{200}, + Repositories: "test-repo", + Tags: "{test-tag1, test-tag2}", + } + + data, err := json.Marshal(criteria) + buffer := bytes.NewBuffer(data) + + headers := make(map[string]string) + headers["X-Scan-Data-Type"] = v1.MimeTypeGenericVulnerabilityReport + + mock.OnAnything(suite.scanExportCtl, "Start").Return(int64(100), nil).Once() + res, err := suite.DoReq(http.MethodPost, url, buffer, headers) + suite.Equal(http.StatusForbidden, res.StatusCode) + suite.NoError(err) +} + +func (suite *ScanExportTestSuite) TestGetScanDataExportExecution() { + + suite.Security.On("IsAuthenticated").Return(true).Once() + suite.Security.On("Can", mock.Anything, mock.Anything, mock.Anything).Return(true).Once() + url := "/export/cve/execution/100" + endTime := time.Now() + startTime := endTime.Add(-10 * time.Minute) + + execution := &export.Execution{ + ID: 100, + UserID: 3, + Status: "Success", + StatusMessage: "", + Trigger: "MANUAL", + StartTime: startTime, + EndTime: endTime, + ExportDataDigest: "datadigest", + UserName: "test-user", + JobName: "test-job", + FilePresent: false, + } + mock.OnAnything(suite.scanExportCtl, "GetExecution").Return(execution, nil).Once() + res, err := suite.DoReq(http.MethodGet, url, nil) + suite.Equal(200, res.StatusCode) + suite.Equal(nil, err) + respData := models.ScanDataExportExecution{} + json.NewDecoder(res.Body).Decode(&respData) + suite.Equal("test-user", respData.UserName) + suite.Equal("test-job", respData.JobName) + suite.Equal(false, respData.FilePresent) + +} + +func (suite *ScanExportTestSuite) TestDownloadScanData() { + suite.Security.On("GetUsername").Return("test-user") + suite.Security.On("IsAuthenticated").Return(true).Once() + suite.Security.On("Can", mock.Anything, mock.Anything, mock.Anything).Return(true).Times(1) + url := "/export/cve/download/100" + endTime := time.Now() + startTime := endTime.Add(-10 * time.Minute) + + execution := &export.Execution{ + ID: int64(100), + UserID: int64(3), + Status: "Success", + StatusMessage: "", + Trigger: "MANUAL", + StartTime: startTime, + EndTime: endTime, + ExportDataDigest: "datadigest", + UserName: "test-user", + } + mock.OnAnything(suite.scanExportCtl, "GetExecution").Return(execution, nil) + mock.OnAnything(suite.scanExportCtl, "DeleteExecution").Return(nil) + + // all BLOB related operations succeed + mock.OnAnything(suite.sysArtifactMgr, "Create").Return(int64(1), nil) + + sampleData := "test,hello,world" + data := io.NopCloser(strings.NewReader(sampleData)) + mock.OnAnything(suite.sysArtifactMgr, "Read").Return(data, nil) + mock.OnAnything(suite.sysArtifactMgr, "Delete").Return(nil) + + res, err := suite.DoReq(http.MethodGet, url, nil) + suite.Equal(200, res.StatusCode) + suite.Equal(nil, err) + + // validate the content of the response + var responseData bytes.Buffer + if _, err := io.Copy(&responseData, res.Body); err == nil { + suite.Equal(sampleData, responseData.String()) + } +} + +func (suite *ScanExportTestSuite) TestDownloadScanDataUserNotOwnerofExport() { + suite.Security.On("GetUsername").Return("test-user1") + suite.Security.On("IsAuthenticated").Return(true).Once() + suite.Security.On("Can", mock.Anything, mock.Anything, mock.Anything).Return(true).Times(1) + url := "/export/cve/download/100" + endTime := time.Now() + startTime := endTime.Add(-10 * time.Minute) + + execution := &export.Execution{ + ID: int64(100), + UserID: int64(3), + Status: "Success", + StatusMessage: "", + Trigger: "MANUAL", + StartTime: startTime, + EndTime: endTime, + ExportDataDigest: "datadigest", + UserName: 
"test-user", + } + mock.OnAnything(suite.scanExportCtl, "GetExecution").Return(execution, nil) + mock.OnAnything(suite.scanExportCtl, "DeleteExecution").Return(nil) + + // all BLOB related operations succeed + mock.OnAnything(suite.sysArtifactMgr, "Create").Return(int64(1), nil) + + sampleData := "test,hello,world" + data := io.NopCloser(strings.NewReader(sampleData)) + mock.OnAnything(suite.sysArtifactMgr, "Read").Return(data, nil) + mock.OnAnything(suite.sysArtifactMgr, "Delete").Return(nil) + + res, err := suite.DoReq(http.MethodGet, url, nil) + suite.Equal(http.StatusUnauthorized, res.StatusCode) + suite.Equal(nil, err) +} + +func (suite *ScanExportTestSuite) TestGetScanDataExportExecutionList() { + + suite.Security.On("IsAuthenticated").Return(true).Once() + suite.Security.On("Can", mock.Anything, mock.Anything, mock.Anything).Return(true).Once() + url, err := url2.Parse("/export/cve/executions") + params := url2.Values{} + params.Add("user_name", "test-user") + url.RawQuery = params.Encode() + endTime := time.Now() + startTime := endTime.Add(-10 * time.Minute) + + execution := &export.Execution{ + ID: 100, + UserID: 3, + Status: "Success", + StatusMessage: "", + Trigger: "MANUAL", + StartTime: startTime, + EndTime: endTime, + ExportDataDigest: "datadigest", + JobName: "test-job", + UserName: "test-user", + } + fmt.Println("URL string : ", url.String()) + mock.OnAnything(suite.scanExportCtl, "ListExecutions").Return([]*export.Execution{execution}, nil).Once() + res, err := suite.DoReq(http.MethodGet, url.String(), nil) + suite.Equal(200, res.StatusCode) + suite.Equal(nil, err) + respData := models.ScanDataExportExecutionList{} + json.NewDecoder(res.Body).Decode(&respData) + suite.Equal(1, len(respData.Items)) + suite.Equal(int64(100), respData.Items[0].ID) +} + +func TestScanExportTestSuite(t *testing.T) { + suite.Run(t, &ScanExportTestSuite{}) +} diff --git a/src/testing/controller/controller.go b/src/testing/controller/controller.go index 68623c6e5..3fb3bd510 100644 --- a/src/testing/controller/controller.go +++ b/src/testing/controller/controller.go @@ -31,3 +31,4 @@ package controller //go:generate mockery --case snake --dir ../../controller/purge --name Controller --output ./purge --outpkg purge //go:generate mockery --case snake --dir ../../controller/jobservice --name SchedulerController --output ./jobservice --outpkg jobservice //go:generate mockery --case snake --dir ../../controller/systemartifact --name Controller --output ./systemartifact --outpkg systemartifact +//go:generate mockery --case snake --dir ../../controller/scandataexport --name Controller --output ./scandataexport --outpkg scandataexport diff --git a/src/testing/controller/scandataexport/controller.go b/src/testing/controller/scandataexport/controller.go new file mode 100644 index 000000000..ce6e65f54 --- /dev/null +++ b/src/testing/controller/scandataexport/controller.go @@ -0,0 +1,136 @@ +// Code generated by mockery v2.12.3. DO NOT EDIT. 
+ +package scandataexport + +import ( + context "context" + + export "github.com/goharbor/harbor/src/pkg/scan/export" + mock "github.com/stretchr/testify/mock" + + task "github.com/goharbor/harbor/src/pkg/task" +) + +// Controller is an autogenerated mock type for the Controller type +type Controller struct { + mock.Mock +} + +// DeleteExecution provides a mock function with given fields: ctx, executionID +func (_m *Controller) DeleteExecution(ctx context.Context, executionID int64) error { + ret := _m.Called(ctx, executionID) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, int64) error); ok { + r0 = rf(ctx, executionID) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// GetExecution provides a mock function with given fields: ctx, executionID +func (_m *Controller) GetExecution(ctx context.Context, executionID int64) (*export.Execution, error) { + ret := _m.Called(ctx, executionID) + + var r0 *export.Execution + if rf, ok := ret.Get(0).(func(context.Context, int64) *export.Execution); ok { + r0 = rf(ctx, executionID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*export.Execution) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, int64) error); ok { + r1 = rf(ctx, executionID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// GetTask provides a mock function with given fields: ctx, executionID +func (_m *Controller) GetTask(ctx context.Context, executionID int64) (*task.Task, error) { + ret := _m.Called(ctx, executionID) + + var r0 *task.Task + if rf, ok := ret.Get(0).(func(context.Context, int64) *task.Task); ok { + r0 = rf(ctx, executionID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*task.Task) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, int64) error); ok { + r1 = rf(ctx, executionID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// ListExecutions provides a mock function with given fields: ctx, userName +func (_m *Controller) ListExecutions(ctx context.Context, userName string) ([]*export.Execution, error) { + ret := _m.Called(ctx, userName) + + var r0 []*export.Execution + if rf, ok := ret.Get(0).(func(context.Context, string) []*export.Execution); ok { + r0 = rf(ctx, userName) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*export.Execution) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, userName) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Start provides a mock function with given fields: ctx, criteria +func (_m *Controller) Start(ctx context.Context, criteria export.Request) (int64, error) { + ret := _m.Called(ctx, criteria) + + var r0 int64 + if rf, ok := ret.Get(0).(func(context.Context, export.Request) int64); ok { + r0 = rf(ctx, criteria) + } else { + r0 = ret.Get(0).(int64) + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, export.Request) error); ok { + r1 = rf(ctx, criteria) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +type NewControllerT interface { + mock.TestingT + Cleanup(func()) +} + +// NewController creates a new instance of Controller. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. 
+func NewController(t NewControllerT) *Controller { + mock := &Controller{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/src/testing/pkg/pkg.go b/src/testing/pkg/pkg.go index d4f5b8ed3..ca2acbd5e 100644 --- a/src/testing/pkg/pkg.go +++ b/src/testing/pkg/pkg.go @@ -62,3 +62,7 @@ package pkg //go:generate mockery --case snake --dir ../../pkg/systemartifact/ --name Selector --output ./systemartifact/cleanup --outpkg cleanup //go:generate mockery --case snake --dir ../../pkg/systemartifact/dao --name DAO --output ./systemartifact/dao --outpkg dao //go:generate mockery --case snake --dir ../../pkg/cached/manifest/redis --name CachedManager --output ./cached/manifest/redis --outpkg redis +//go:generate mockery --case snake --dir ../../pkg/scan/export --name FilterProcessor --output ./scan/export --outpkg export +//go:generate mockery --case snake --dir ../../pkg/scan/export --name Manager --output ./scan/export --outpkg export +//go:generate mockery --case snake --dir ../../pkg/scan/export --name ArtifactDigestCalculator --output ./scan/export --outpkg export +//go:generate mockery --case snake --dir ../../pkg/registry --name Client --output ./registry --outpkg registry --filename fake_registry_client.go diff --git a/src/testing/pkg/registry/client.go b/src/testing/pkg/registry/client.go index 837e74796..608493b84 100644 --- a/src/testing/pkg/registry/client.go +++ b/src/testing/pkg/registry/client.go @@ -114,7 +114,7 @@ func (f *FakeClient) MountBlob(srcRepository, digest, dstRepository string) (err // DeleteBlob ... func (f *FakeClient) DeleteBlob(repository, digest string) (err error) { - args := f.Called() + args := f.Called(repository, digest) return args.Error(0) } diff --git a/src/testing/pkg/registry/fake_registry_client.go b/src/testing/pkg/registry/fake_registry_client.go new file mode 100644 index 000000000..813789a67 --- /dev/null +++ b/src/testing/pkg/registry/fake_registry_client.go @@ -0,0 +1,325 @@ +// Code generated by mockery v2.12.3. DO NOT EDIT. 
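The one-line change to the hand-written FakeClient above, forwarding repository and digest into f.Called, is what allows the systemartifact manager tests to assert on the exact DeleteBlob arguments. A minimal sketch of that pattern, with an assumed test name:

```go
package registry_test

import (
	"testing"

	"github.com/goharbor/harbor/src/testing/pkg/registry"
)

func TestDeleteBlobWithExactArgs(t *testing.T) {
	cli := &registry.FakeClient{}
	cli.On("DeleteBlob", "sys_harb0r/test_vendor/test_repo", "test_digest").Return(nil)

	if err := cli.DeleteBlob("sys_harb0r/test_vendor/test_repo", "test_digest"); err != nil {
		t.Fatal(err)
	}

	// Because DeleteBlob now forwards its arguments to Called, the assertion
	// can pin down exactly which repository/digest pair was deleted.
	cli.AssertCalled(t, "DeleteBlob", "sys_harb0r/test_vendor/test_repo", "test_digest")
}
```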
+ +package registry + +import ( + http "net/http" + + distribution "github.com/docker/distribution" + + io "io" + + mock "github.com/stretchr/testify/mock" +) + +// Client is an autogenerated mock type for the Client type +type Client struct { + mock.Mock +} + +// BlobExist provides a mock function with given fields: repository, digest +func (_m *Client) BlobExist(repository string, digest string) (bool, error) { + ret := _m.Called(repository, digest) + + var r0 bool + if rf, ok := ret.Get(0).(func(string, string) bool); ok { + r0 = rf(repository, digest) + } else { + r0 = ret.Get(0).(bool) + } + + var r1 error + if rf, ok := ret.Get(1).(func(string, string) error); ok { + r1 = rf(repository, digest) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Catalog provides a mock function with given fields: +func (_m *Client) Catalog() ([]string, error) { + ret := _m.Called() + + var r0 []string + if rf, ok := ret.Get(0).(func() []string); ok { + r0 = rf() + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]string) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func() error); ok { + r1 = rf() + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Copy provides a mock function with given fields: srcRepository, srcReference, dstRepository, dstReference, override +func (_m *Client) Copy(srcRepository string, srcReference string, dstRepository string, dstReference string, override bool) error { + ret := _m.Called(srcRepository, srcReference, dstRepository, dstReference, override) + + var r0 error + if rf, ok := ret.Get(0).(func(string, string, string, string, bool) error); ok { + r0 = rf(srcRepository, srcReference, dstRepository, dstReference, override) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// DeleteBlob provides a mock function with given fields: repository, digest +func (_m *Client) DeleteBlob(repository string, digest string) error { + ret := _m.Called(repository, digest) + + var r0 error + if rf, ok := ret.Get(0).(func(string, string) error); ok { + r0 = rf(repository, digest) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// DeleteManifest provides a mock function with given fields: repository, reference +func (_m *Client) DeleteManifest(repository string, reference string) error { + ret := _m.Called(repository, reference) + + var r0 error + if rf, ok := ret.Get(0).(func(string, string) error); ok { + r0 = rf(repository, reference) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Do provides a mock function with given fields: req +func (_m *Client) Do(req *http.Request) (*http.Response, error) { + ret := _m.Called(req) + + var r0 *http.Response + if rf, ok := ret.Get(0).(func(*http.Request) *http.Response); ok { + r0 = rf(req) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(*http.Response) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(*http.Request) error); ok { + r1 = rf(req) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// ListTags provides a mock function with given fields: repository +func (_m *Client) ListTags(repository string) ([]string, error) { + ret := _m.Called(repository) + + var r0 []string + if rf, ok := ret.Get(0).(func(string) []string); ok { + r0 = rf(repository) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]string) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(string) error); ok { + r1 = rf(repository) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// ManifestExist provides a mock function with given fields: repository, reference +func (_m *Client) 
ManifestExist(repository string, reference string) (bool, *distribution.Descriptor, error) { + ret := _m.Called(repository, reference) + + var r0 bool + if rf, ok := ret.Get(0).(func(string, string) bool); ok { + r0 = rf(repository, reference) + } else { + r0 = ret.Get(0).(bool) + } + + var r1 *distribution.Descriptor + if rf, ok := ret.Get(1).(func(string, string) *distribution.Descriptor); ok { + r1 = rf(repository, reference) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(*distribution.Descriptor) + } + } + + var r2 error + if rf, ok := ret.Get(2).(func(string, string) error); ok { + r2 = rf(repository, reference) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} + +// MountBlob provides a mock function with given fields: srcRepository, digest, dstRepository +func (_m *Client) MountBlob(srcRepository string, digest string, dstRepository string) error { + ret := _m.Called(srcRepository, digest, dstRepository) + + var r0 error + if rf, ok := ret.Get(0).(func(string, string, string) error); ok { + r0 = rf(srcRepository, digest, dstRepository) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// Ping provides a mock function with given fields: +func (_m *Client) Ping() error { + ret := _m.Called() + + var r0 error + if rf, ok := ret.Get(0).(func() error); ok { + r0 = rf() + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// PullBlob provides a mock function with given fields: repository, digest +func (_m *Client) PullBlob(repository string, digest string) (int64, io.ReadCloser, error) { + ret := _m.Called(repository, digest) + + var r0 int64 + if rf, ok := ret.Get(0).(func(string, string) int64); ok { + r0 = rf(repository, digest) + } else { + r0 = ret.Get(0).(int64) + } + + var r1 io.ReadCloser + if rf, ok := ret.Get(1).(func(string, string) io.ReadCloser); ok { + r1 = rf(repository, digest) + } else { + if ret.Get(1) != nil { + r1 = ret.Get(1).(io.ReadCloser) + } + } + + var r2 error + if rf, ok := ret.Get(2).(func(string, string) error); ok { + r2 = rf(repository, digest) + } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} + +// PullManifest provides a mock function with given fields: repository, reference, acceptedMediaTypes +func (_m *Client) PullManifest(repository string, reference string, acceptedMediaTypes ...string) (distribution.Manifest, string, error) { + _va := make([]interface{}, len(acceptedMediaTypes)) + for _i := range acceptedMediaTypes { + _va[_i] = acceptedMediaTypes[_i] + } + var _ca []interface{} + _ca = append(_ca, repository, reference) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 distribution.Manifest + if rf, ok := ret.Get(0).(func(string, string, ...string) distribution.Manifest); ok { + r0 = rf(repository, reference, acceptedMediaTypes...) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).(distribution.Manifest) + } + } + + var r1 string + if rf, ok := ret.Get(1).(func(string, string, ...string) string); ok { + r1 = rf(repository, reference, acceptedMediaTypes...) + } else { + r1 = ret.Get(1).(string) + } + + var r2 error + if rf, ok := ret.Get(2).(func(string, string, ...string) error); ok { + r2 = rf(repository, reference, acceptedMediaTypes...) 
+ } else { + r2 = ret.Error(2) + } + + return r0, r1, r2 +} + +// PushBlob provides a mock function with given fields: repository, digest, size, blob +func (_m *Client) PushBlob(repository string, digest string, size int64, blob io.Reader) error { + ret := _m.Called(repository, digest, size, blob) + + var r0 error + if rf, ok := ret.Get(0).(func(string, string, int64, io.Reader) error); ok { + r0 = rf(repository, digest, size, blob) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// PushManifest provides a mock function with given fields: repository, reference, mediaType, payload +func (_m *Client) PushManifest(repository string, reference string, mediaType string, payload []byte) (string, error) { + ret := _m.Called(repository, reference, mediaType, payload) + + var r0 string + if rf, ok := ret.Get(0).(func(string, string, string, []byte) string); ok { + r0 = rf(repository, reference, mediaType, payload) + } else { + r0 = ret.Get(0).(string) + } + + var r1 error + if rf, ok := ret.Get(1).(func(string, string, string, []byte) error); ok { + r1 = rf(repository, reference, mediaType, payload) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +type NewClientT interface { + mock.TestingT + Cleanup(func()) +} + +// NewClient creates a new instance of Client. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +func NewClient(t NewClientT) *Client { + mock := &Client{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/src/testing/pkg/scan/dao/scan/vulnerability_record_dao.go b/src/testing/pkg/scan/dao/scan/vulnerability_record_dao.go new file mode 100644 index 000000000..a41217961 --- /dev/null +++ b/src/testing/pkg/scan/dao/scan/vulnerability_record_dao.go @@ -0,0 +1,256 @@ +// Code generated by mockery v2.1.0. DO NOT EDIT. + +package scan + +import ( + context "context" + + q "github.com/goharbor/harbor/src/lib/q" + mock "github.com/stretchr/testify/mock" + + scan "github.com/goharbor/harbor/src/pkg/scan/dao/scan" +) + +// VulnerabilityRecordDao is an autogenerated mock type for the VulnerabilityRecordDao type +type VulnerabilityRecordDao struct { + mock.Mock +} + +// Create provides a mock function with given fields: ctx, vr +func (_m *VulnerabilityRecordDao) Create(ctx context.Context, vr *scan.VulnerabilityRecord) (int64, error) { + ret := _m.Called(ctx, vr) + + var r0 int64 + if rf, ok := ret.Get(0).(func(context.Context, *scan.VulnerabilityRecord) int64); ok { + r0 = rf(ctx, vr) + } else { + r0 = ret.Get(0).(int64) + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *scan.VulnerabilityRecord) error); ok { + r1 = rf(ctx, vr) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Delete provides a mock function with given fields: ctx, vr +func (_m *VulnerabilityRecordDao) Delete(ctx context.Context, vr *scan.VulnerabilityRecord) error { + ret := _m.Called(ctx, vr) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, *scan.VulnerabilityRecord) error); ok { + r0 = rf(ctx, vr) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// DeleteForDigests provides a mock function with given fields: ctx, digests +func (_m *VulnerabilityRecordDao) DeleteForDigests(ctx context.Context, digests ...string) (int64, error) { + _va := make([]interface{}, len(digests)) + for _i := range digests { + _va[_i] = digests[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) 
+ + var r0 int64 + if rf, ok := ret.Get(0).(func(context.Context, ...string) int64); ok { + r0 = rf(ctx, digests...) + } else { + r0 = ret.Get(0).(int64) + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, ...string) error); ok { + r1 = rf(ctx, digests...) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// DeleteForReport provides a mock function with given fields: ctx, reportUUID +func (_m *VulnerabilityRecordDao) DeleteForReport(ctx context.Context, reportUUID string) (int64, error) { + ret := _m.Called(ctx, reportUUID) + + var r0 int64 + if rf, ok := ret.Get(0).(func(context.Context, string) int64); ok { + r0 = rf(ctx, reportUUID) + } else { + r0 = ret.Get(0).(int64) + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, reportUUID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// DeleteForScanner provides a mock function with given fields: ctx, registrationUUID +func (_m *VulnerabilityRecordDao) DeleteForScanner(ctx context.Context, registrationUUID string) (int64, error) { + ret := _m.Called(ctx, registrationUUID) + + var r0 int64 + if rf, ok := ret.Get(0).(func(context.Context, string) int64); ok { + r0 = rf(ctx, registrationUUID) + } else { + r0 = ret.Get(0).(int64) + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, registrationUUID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// GetForReport provides a mock function with given fields: ctx, reportUUID +func (_m *VulnerabilityRecordDao) GetForReport(ctx context.Context, reportUUID string) ([]*scan.VulnerabilityRecord, error) { + ret := _m.Called(ctx, reportUUID) + + var r0 []*scan.VulnerabilityRecord + if rf, ok := ret.Get(0).(func(context.Context, string) []*scan.VulnerabilityRecord); ok { + r0 = rf(ctx, reportUUID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*scan.VulnerabilityRecord) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, reportUUID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// GetForScanner provides a mock function with given fields: ctx, registrationUUID +func (_m *VulnerabilityRecordDao) GetForScanner(ctx context.Context, registrationUUID string) ([]*scan.VulnerabilityRecord, error) { + ret := _m.Called(ctx, registrationUUID) + + var r0 []*scan.VulnerabilityRecord + if rf, ok := ret.Get(0).(func(context.Context, string) []*scan.VulnerabilityRecord); ok { + r0 = rf(ctx, registrationUUID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*scan.VulnerabilityRecord) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, registrationUUID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// GetRecordIdsForScanner provides a mock function with given fields: ctx, registrationUUID +func (_m *VulnerabilityRecordDao) GetRecordIdsForScanner(ctx context.Context, registrationUUID string) ([]int, error) { + ret := _m.Called(ctx, registrationUUID) + + var r0 []int + if rf, ok := ret.Get(0).(func(context.Context, string) []int); ok { + r0 = rf(ctx, registrationUUID) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]int) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, string) error); ok { + r1 = rf(ctx, registrationUUID) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// InsertForReport provides a mock function with given fields: ctx, reportUUID, vulnerabilityRecordIDs 
+func (_m *VulnerabilityRecordDao) InsertForReport(ctx context.Context, reportUUID string, vulnerabilityRecordIDs ...int64) error { + _va := make([]interface{}, len(vulnerabilityRecordIDs)) + for _i := range vulnerabilityRecordIDs { + _va[_i] = vulnerabilityRecordIDs[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, reportUUID) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, string, ...int64) error); ok { + r0 = rf(ctx, reportUUID, vulnerabilityRecordIDs...) + } else { + r0 = ret.Error(0) + } + + return r0 +} + +// List provides a mock function with given fields: ctx, query +func (_m *VulnerabilityRecordDao) List(ctx context.Context, query *q.Query) ([]*scan.VulnerabilityRecord, error) { + ret := _m.Called(ctx, query) + + var r0 []*scan.VulnerabilityRecord + if rf, ok := ret.Get(0).(func(context.Context, *q.Query) []*scan.VulnerabilityRecord); ok { + r0 = rf(ctx, query) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*scan.VulnerabilityRecord) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, *q.Query) error); ok { + r1 = rf(ctx, query) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// Update provides a mock function with given fields: ctx, vr, cols +func (_m *VulnerabilityRecordDao) Update(ctx context.Context, vr *scan.VulnerabilityRecord, cols ...string) error { + _va := make([]interface{}, len(cols)) + for _i := range cols { + _va[_i] = cols[_i] + } + var _ca []interface{} + _ca = append(_ca, ctx, vr) + _ca = append(_ca, _va...) + ret := _m.Called(_ca...) + + var r0 error + if rf, ok := ret.Get(0).(func(context.Context, *scan.VulnerabilityRecord, ...string) error); ok { + r0 = rf(ctx, vr, cols...) + } else { + r0 = ret.Error(0) + } + + return r0 +} diff --git a/src/testing/pkg/scan/export/artifact_digest_calculator.go b/src/testing/pkg/scan/export/artifact_digest_calculator.go new file mode 100644 index 000000000..dd8b76080 --- /dev/null +++ b/src/testing/pkg/scan/export/artifact_digest_calculator.go @@ -0,0 +1,50 @@ +// Code generated by mockery v2.12.3. DO NOT EDIT. + +package export + +import ( + digest "github.com/opencontainers/go-digest" + + mock "github.com/stretchr/testify/mock" +) + +// ArtifactDigestCalculator is an autogenerated mock type for the ArtifactDigestCalculator type +type ArtifactDigestCalculator struct { + mock.Mock +} + +// Calculate provides a mock function with given fields: fileName +func (_m *ArtifactDigestCalculator) Calculate(fileName string) (digest.Digest, error) { + ret := _m.Called(fileName) + + var r0 digest.Digest + if rf, ok := ret.Get(0).(func(string) digest.Digest); ok { + r0 = rf(fileName) + } else { + r0 = ret.Get(0).(digest.Digest) + } + + var r1 error + if rf, ok := ret.Get(1).(func(string) error); ok { + r1 = rf(fileName) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +type NewArtifactDigestCalculatorT interface { + mock.TestingT + Cleanup(func()) +} + +// NewArtifactDigestCalculator creates a new instance of ArtifactDigestCalculator. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. 
+func NewArtifactDigestCalculator(t NewArtifactDigestCalculatorT) *ArtifactDigestCalculator { + mock := &ArtifactDigestCalculator{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/src/testing/pkg/scan/export/filter_processor.go b/src/testing/pkg/scan/export/filter_processor.go new file mode 100644 index 000000000..50ce99804 --- /dev/null +++ b/src/testing/pkg/scan/export/filter_processor.go @@ -0,0 +1,100 @@ +// Code generated by mockery v2.12.3. DO NOT EDIT. + +package export + +import ( + context "context" + + mock "github.com/stretchr/testify/mock" + + selector "github.com/goharbor/harbor/src/lib/selector" +) + +// FilterProcessor is an autogenerated mock type for the FilterProcessor type +type FilterProcessor struct { + mock.Mock +} + +// ProcessProjectFilter provides a mock function with given fields: ctx, userName, projectsToFilter +func (_m *FilterProcessor) ProcessProjectFilter(ctx context.Context, userName string, projectsToFilter []int64) ([]int64, error) { + ret := _m.Called(ctx, userName, projectsToFilter) + + var r0 []int64 + if rf, ok := ret.Get(0).(func(context.Context, string, []int64) []int64); ok { + r0 = rf(ctx, userName, projectsToFilter) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]int64) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, string, []int64) error); ok { + r1 = rf(ctx, userName, projectsToFilter) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// ProcessRepositoryFilter provides a mock function with given fields: ctx, filter, projectIds +func (_m *FilterProcessor) ProcessRepositoryFilter(ctx context.Context, filter string, projectIds []int64) ([]*selector.Candidate, error) { + ret := _m.Called(ctx, filter, projectIds) + + var r0 []*selector.Candidate + if rf, ok := ret.Get(0).(func(context.Context, string, []int64) []*selector.Candidate); ok { + r0 = rf(ctx, filter, projectIds) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*selector.Candidate) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, string, []int64) error); ok { + r1 = rf(ctx, filter, projectIds) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +// ProcessTagFilter provides a mock function with given fields: ctx, filter, repositoryIds +func (_m *FilterProcessor) ProcessTagFilter(ctx context.Context, filter string, repositoryIds []int64) ([]*selector.Candidate, error) { + ret := _m.Called(ctx, filter, repositoryIds) + + var r0 []*selector.Candidate + if rf, ok := ret.Get(0).(func(context.Context, string, []int64) []*selector.Candidate); ok { + r0 = rf(ctx, filter, repositoryIds) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]*selector.Candidate) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, string, []int64) error); ok { + r1 = rf(ctx, filter, repositoryIds) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +type NewFilterProcessorT interface { + mock.TestingT + Cleanup(func()) +} + +// NewFilterProcessor creates a new instance of FilterProcessor. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. 
+func NewFilterProcessor(t NewFilterProcessorT) *FilterProcessor { + mock := &FilterProcessor{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/src/testing/pkg/scan/export/manager.go b/src/testing/pkg/scan/export/manager.go new file mode 100644 index 000000000..c66fa69b3 --- /dev/null +++ b/src/testing/pkg/scan/export/manager.go @@ -0,0 +1,53 @@ +// Code generated by mockery v2.12.3. DO NOT EDIT. + +package export + +import ( + context "context" + + export "github.com/goharbor/harbor/src/pkg/scan/export" + mock "github.com/stretchr/testify/mock" +) + +// Manager is an autogenerated mock type for the Manager type +type Manager struct { + mock.Mock +} + +// Fetch provides a mock function with given fields: ctx, params +func (_m *Manager) Fetch(ctx context.Context, params export.Params) ([]export.Data, error) { + ret := _m.Called(ctx, params) + + var r0 []export.Data + if rf, ok := ret.Get(0).(func(context.Context, export.Params) []export.Data); ok { + r0 = rf(ctx, params) + } else { + if ret.Get(0) != nil { + r0 = ret.Get(0).([]export.Data) + } + } + + var r1 error + if rf, ok := ret.Get(1).(func(context.Context, export.Params) error); ok { + r1 = rf(ctx, params) + } else { + r1 = ret.Error(1) + } + + return r0, r1 +} + +type NewManagerT interface { + mock.TestingT + Cleanup(func()) +} + +// NewManager creates a new instance of Manager. It also registers a testing interface on the mock and a cleanup function to assert the mocks expectations. +func NewManager(t NewManagerT) *Manager { + mock := &Manager{} + mock.Mock.Test(t) + + t.Cleanup(func() { mock.AssertExpectations(t) }) + + return mock +} diff --git a/src/vendor/github.com/gocarina/gocsv/.gitignore b/src/vendor/github.com/gocarina/gocsv/.gitignore new file mode 100644 index 000000000..485dee64b --- /dev/null +++ b/src/vendor/github.com/gocarina/gocsv/.gitignore @@ -0,0 +1 @@ +.idea diff --git a/src/vendor/github.com/gocarina/gocsv/.travis.yml b/src/vendor/github.com/gocarina/gocsv/.travis.yml new file mode 100644 index 000000000..61c24c6c9 --- /dev/null +++ b/src/vendor/github.com/gocarina/gocsv/.travis.yml @@ -0,0 +1,4 @@ +language: go +arch: + - amd64 + - ppc64le diff --git a/src/vendor/github.com/gocarina/gocsv/LICENSE b/src/vendor/github.com/gocarina/gocsv/LICENSE new file mode 100644 index 000000000..052a37119 --- /dev/null +++ b/src/vendor/github.com/gocarina/gocsv/LICENSE @@ -0,0 +1,21 @@ +The MIT License (MIT) + +Copyright (c) 2014 Jonathan Picques + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
\ No newline at end of file
diff --git a/src/vendor/github.com/gocarina/gocsv/README.md b/src/vendor/github.com/gocarina/gocsv/README.md
new file mode 100644
index 000000000..e606ac7bb
--- /dev/null
+++ b/src/vendor/github.com/gocarina/gocsv/README.md
@@ -0,0 +1,168 @@
+Go CSV
+=====
+
+The GoCSV package aims to provide easy serialization and deserialization functions for working with CSV in Golang
+
+API and techniques inspired by https://godoc.org/gopkg.in/mgo.v2
+
+[![GoDoc](https://godoc.org/github.com/gocarina/gocsv?status.png)](https://godoc.org/github.com/gocarina/gocsv)
+[![Build Status](https://travis-ci.org/gocarina/gocsv.svg?branch=master)](https://travis-ci.org/gocarina/gocsv)
+
+Installation
+=====
+
+```go get -u github.com/gocarina/gocsv```
+
+Full example
+=====
+
+Consider the following CSV file
+
+```csv
+
+client_id,client_name,client_age
+1,Jose,42
+2,Daniel,26
+3,Vincent,32
+
+```
+
+Easy binding in Go!
+---
+
+```go
+
+package main
+
+import (
+	"fmt"
+	"os"
+
+	"github.com/gocarina/gocsv"
+)
+
+type Client struct { // Our example struct, you can use "-" to ignore a field
+	Id      string `csv:"client_id"`
+	Name    string `csv:"client_name"`
+	Age     string `csv:"client_age"`
+	NotUsed string `csv:"-"`
+}
+
+func main() {
+	clientsFile, err := os.OpenFile("clients.csv", os.O_RDWR|os.O_CREATE, os.ModePerm)
+	if err != nil {
+		panic(err)
+	}
+	defer clientsFile.Close()
+
+	clients := []*Client{}
+
+	if err := gocsv.UnmarshalFile(clientsFile, &clients); err != nil { // Load clients from file
+		panic(err)
+	}
+	for _, client := range clients {
+		fmt.Println("Hello", client.Name)
+	}
+
+	if _, err := clientsFile.Seek(0, 0); err != nil { // Go to the start of the file
+		panic(err)
+	}
+
+	clients = append(clients, &Client{Id: "12", Name: "John", Age: "21"}) // Add clients
+	clients = append(clients, &Client{Id: "13", Name: "Fred"})
+	clients = append(clients, &Client{Id: "14", Name: "James", Age: "32"})
+	clients = append(clients, &Client{Id: "15", Name: "Danny"})
+	csvContent, err := gocsv.MarshalString(&clients) // Get all clients as CSV string
+	//err = gocsv.MarshalFile(&clients, clientsFile) // Use this to save the CSV back to the file
+	if err != nil {
+		panic(err)
+	}
+	fmt.Println(csvContent) // Display all clients as CSV string
+
+}
+
+```
+
+Customizable Converters
+---
+
+```go
+
+type DateTime struct {
+	time.Time
+}
+
+// Convert the internal date as CSV string
+func (date *DateTime) MarshalCSV() (string, error) {
+	return date.Time.Format("20060201"), nil
+}
+
+// You could also use the standard Stringer interface
+func (date *DateTime) String() (string) {
+	return date.Time.String() // Delegate to the embedded time.Time to avoid infinite recursion
+}
+
+// Convert the CSV string as internal date
+func (date *DateTime) UnmarshalCSV(csv string) (err error) {
+	date.Time, err = time.Parse("20060201", csv)
+	return err
+}
+
+type Client struct { // Our example struct with a custom type (DateTime)
+	Id       string   `csv:"id"`
+	Name     string   `csv:"name"`
+	Employed DateTime `csv:"employed"`
+}
+
+```
+
+Customizable CSV Reader / Writer
+---
+
+```go
+
+func main() {
+	...
+
+	gocsv.SetCSVReader(func(in io.Reader) gocsv.CSVReader {
+		r := csv.NewReader(in)
+		r.Comma = '|'
+		return r // Allows use of pipe as delimiter
+	})
+
+	...
+
+	gocsv.SetCSVReader(func(in io.Reader) gocsv.CSVReader {
+		r := csv.NewReader(in)
+		r.LazyQuotes = true
+		r.Comma = '.'
+		return r // Allows use of dot as delimiter and quotes in CSV
+	})
+
+	...
+
+	gocsv.SetCSVReader(func(in io.Reader) gocsv.CSVReader {
+		//return csv.NewReader(in)
+		return gocsv.LazyCSVReader(in) // Allows use of quotes in CSV
+	})
+
+	...
+
+	gocsv.UnmarshalFile(file, &clients)
+
+	...
+
+	gocsv.SetCSVWriter(func(out io.Writer) *SafeCSVWriter {
+		writer := csv.NewWriter(out)
+		writer.Comma = '|'
+		return gocsv.NewSafeCSVWriter(writer)
+	})
+
+	...
+
+	gocsv.MarshalFile(&clients, file)
+
+	...
+}
+
+```
diff --git a/src/vendor/github.com/gocarina/gocsv/csv.go b/src/vendor/github.com/gocarina/gocsv/csv.go
new file mode 100644
index 000000000..a20e2d709
--- /dev/null
+++ b/src/vendor/github.com/gocarina/gocsv/csv.go
@@ -0,0 +1,517 @@
+// Copyright 2014 Jonathan Picques. All rights reserved.
+// Use of this source code is governed by an MIT license
+// The license can be found in the LICENSE file.
+
+// The GoCSV package aims to provide easy CSV serialization and deserialization to the golang programming language
+
+package gocsv
+
+import (
+	"bytes"
+	"encoding/csv"
+	"fmt"
+	"io"
+	"os"
+	"reflect"
+	"strings"
+	"sync"
+)
+
+// FailIfUnmatchedStructTags indicates whether it is considered an error when there is an unmatched
+// struct tag.
+var FailIfUnmatchedStructTags = false
+
+// FailIfDoubleHeaderNames indicates whether it is considered an error when a header name is repeated
+// in the csv header.
+var FailIfDoubleHeaderNames = false
+
+// ShouldAlignDuplicateHeadersWithStructFieldOrder indicates whether we should align duplicate CSV
+// headers per their alignment in the struct definition.
+var ShouldAlignDuplicateHeadersWithStructFieldOrder = false
+
+// TagName defines the key in the struct field's tag to scan
+var TagName = "csv"
+
+// TagSeparator defines the separator string for multiple csv tags in struct fields
+var TagSeparator = ","
+
+// Normalizer is a function that takes and returns a string. It is applied to
+// struct and header field values before they are compared. It can be used to alter
+// names for comparison. For instance, you could allow case insensitive matching
+// or convert '-' to '_'.
+type Normalizer func(string) string
+
+type ErrorHandler func(*csv.ParseError) bool
+
+// normalizeName function initially set to a nop Normalizer.
+var normalizeName = DefaultNameNormalizer()
+
+// DefaultNameNormalizer is a nop Normalizer.
+func DefaultNameNormalizer() Normalizer { return func(s string) string { return s } }
+
+// SetHeaderNormalizer sets the normalizer used to normalize struct and header field names.
+func SetHeaderNormalizer(f Normalizer) {
+	normalizeName = f
+	// Need to clear the cache when the header normalizer changes.
+	structInfoCache = sync.Map{}
+}
+
+// --------------------------------------------------------------------------
+// CSVWriter used to format CSV
+
+var selfCSVWriter = DefaultCSVWriter
+
+// DefaultCSVWriter is the default SafeCSVWriter used to format CSV (cf. csv.NewWriter)
+func DefaultCSVWriter(out io.Writer) *SafeCSVWriter {
+	writer := NewSafeCSVWriter(csv.NewWriter(out))
+
+	// As only one rune can be defined as a CSV separator, we are going to trim
+	// the custom tag separator and use the first rune.
+	if runes := []rune(strings.TrimSpace(TagSeparator)); len(runes) > 0 {
+		writer.Comma = runes[0]
+	}
+
+	return writer
+}
+
+// SetCSVWriter sets the SafeCSVWriter used to format CSV.
+func SetCSVWriter(csvWriter func(io.Writer) *SafeCSVWriter) {
+	selfCSVWriter = csvWriter
+}
+
+func getCSVWriter(out io.Writer) *SafeCSVWriter {
+	return selfCSVWriter(out)
+}
+
+// --------------------------------------------------------------------------
+// CSVReader used to parse CSV
+
+var selfCSVReader = DefaultCSVReader
+
+// DefaultCSVReader is the default CSV reader used to parse CSV (cf. csv.NewReader)
+func DefaultCSVReader(in io.Reader) CSVReader {
+	return csv.NewReader(in)
+}
+
+// LazyCSVReader returns a lazy CSV reader, with LazyQuotes and TrimLeadingSpace.
+func LazyCSVReader(in io.Reader) CSVReader {
+	csvReader := csv.NewReader(in)
+	csvReader.LazyQuotes = true
+	csvReader.TrimLeadingSpace = true
+	return csvReader
+}
+
+// SetCSVReader sets the CSV reader used to parse CSV.
+func SetCSVReader(csvReader func(io.Reader) CSVReader) {
+	selfCSVReader = csvReader
+}
+
+func getCSVReader(in io.Reader) CSVReader {
+	return selfCSVReader(in)
+}
+
+// --------------------------------------------------------------------------
+// Marshal functions
+
+// MarshalFile saves the interface as CSV in the file.
+func MarshalFile(in interface{}, file *os.File) (err error) {
+	return Marshal(in, file)
+}
+
+// MarshalString returns the CSV string from the interface.
+func MarshalString(in interface{}) (out string, err error) {
+	bufferString := bytes.NewBufferString(out)
+	if err := Marshal(in, bufferString); err != nil {
+		return "", err
+	}
+	return bufferString.String(), nil
+}
+
+// MarshalBytes returns the CSV bytes from the interface.
+func MarshalBytes(in interface{}) (out []byte, err error) {
+	bufferString := bytes.NewBuffer(out)
+	if err := Marshal(in, bufferString); err != nil {
+		return nil, err
+	}
+	return bufferString.Bytes(), nil
+}
+
+// Marshal returns the CSV in writer from the interface.
+func Marshal(in interface{}, out io.Writer) (err error) {
+	writer := getCSVWriter(out)
+	return writeTo(writer, in, false)
+}
+
+// MarshalWithoutHeaders returns the CSV in writer from the interface.
+func MarshalWithoutHeaders(in interface{}, out io.Writer) (err error) {
+	writer := getCSVWriter(out)
+	return writeTo(writer, in, true)
+}
+
+// MarshalChan writes the values read from the channel as CSV.
+func MarshalChan(c <-chan interface{}, out CSVWriter) error {
+	return writeFromChan(out, c)
+}
+
+// MarshalCSV returns the CSV in writer from the interface.
+func MarshalCSV(in interface{}, out CSVWriter) (err error) {
+	return writeTo(out, in, false)
+}
+
+// MarshalCSVWithoutHeaders returns the CSV in writer from the interface.
+func MarshalCSVWithoutHeaders(in interface{}, out CSVWriter) (err error) {
+	return writeTo(out, in, true)
+}
+
+// --------------------------------------------------------------------------
+// Unmarshal functions
+
+// UnmarshalFile parses the CSV from the file in the interface.
+func UnmarshalFile(in *os.File, out interface{}) error {
+	return Unmarshal(in, out)
+}
+
+// UnmarshalFileWithErrorHandler parses the CSV from the file in the interface, using the given error handler.
+func UnmarshalFileWithErrorHandler(in *os.File, errHandler ErrorHandler, out interface{}) error {
+	return UnmarshalWithErrorHandler(in, errHandler, out)
+}
+
+// UnmarshalString parses the CSV from the string in the interface.
+func UnmarshalString(in string, out interface{}) error {
+	return Unmarshal(strings.NewReader(in), out)
+}
+
+// UnmarshalBytes parses the CSV from the bytes in the interface.
+func UnmarshalBytes(in []byte, out interface{}) error {
+	return Unmarshal(bytes.NewReader(in), out)
+}
+
+// Unmarshal parses the CSV from the reader in the interface.
+func Unmarshal(in io.Reader, out interface{}) error {
+	return readTo(newSimpleDecoderFromReader(in), out)
+}
+
+// UnmarshalWithErrorHandler parses the CSV from the reader in the interface, using the given error handler.
+func UnmarshalWithErrorHandler(in io.Reader, errHandle ErrorHandler, out interface{}) error {
+	return readToWithErrorHandler(newSimpleDecoderFromReader(in), errHandle, out)
+}
+
+// UnmarshalWithoutHeaders parses the CSV from the reader in the interface.
+func UnmarshalWithoutHeaders(in io.Reader, out interface{}) error {
+	return readToWithoutHeaders(newSimpleDecoderFromReader(in), out)
+}
+
+// UnmarshalCSVWithoutHeaders parses a headerless CSV with the passed-in CSV reader
+func UnmarshalCSVWithoutHeaders(in CSVReader, out interface{}) error {
+	return readToWithoutHeaders(csvDecoder{in}, out)
+}
+
+// UnmarshalDecoder parses the CSV from the decoder in the interface
+func UnmarshalDecoder(in Decoder, out interface{}) error {
+	return readTo(in, out)
+}
+
+// UnmarshalCSV parses the CSV from the reader in the interface.
+func UnmarshalCSV(in CSVReader, out interface{}) error {
+	return readTo(csvDecoder{in}, out)
+}
+
+// UnmarshalCSVToMap parses a CSV of 2 columns into a map.
+func UnmarshalCSVToMap(in CSVReader, out interface{}) error {
+	decoder := NewSimpleDecoderFromCSVReader(in)
+	header, err := decoder.getCSVRow()
+	if err != nil {
+		return err
+	}
+	if len(header) != 2 {
+		return fmt.Errorf("maps can only be created for csv of two columns")
+	}
+	outValue, outType := getConcreteReflectValueAndType(out)
+	if outType.Kind() != reflect.Map {
+		return fmt.Errorf("cannot use " + outType.String() + ", only map supported")
+	}
+	keyType := outType.Key()
+	valueType := outType.Elem()
+	outValue.Set(reflect.MakeMap(outType))
+	for {
+		key := reflect.New(keyType)
+		value := reflect.New(valueType)
+		line, err := decoder.getCSVRow()
+		if err == io.EOF {
+			break
+		} else if err != nil {
+			return err
+		}
+		if err := setField(key, line[0], false); err != nil {
+			return err
+		}
+		if err := setField(value, line[1], false); err != nil {
+			return err
+		}
+		outValue.SetMapIndex(key.Elem(), value.Elem())
+	}
+	return nil
+}
+
+// UnmarshalToChan parses the CSV from the reader and send each value in the chan c.
+// The channel must have a concrete type.
+func UnmarshalToChan(in io.Reader, c interface{}) error {
+	if c == nil {
+		return fmt.Errorf("gocsv: channel is %v", c)
+	}
+	return readEach(newSimpleDecoderFromReader(in), c)
+}
+
+// UnmarshalToChanWithoutHeaders parses the CSV from the reader and send each value in the chan c.
+// The channel must have a concrete type.
+func UnmarshalToChanWithoutHeaders(in io.Reader, c interface{}) error {
+	if c == nil {
+		return fmt.Errorf("gocsv: channel is %v", c)
+	}
+	return readEachWithoutHeaders(newSimpleDecoderFromReader(in), c)
+}
+
+// UnmarshalDecoderToChan parses the CSV from the decoder and send each value in the chan c.
+// The channel must have a concrete type.
+func UnmarshalDecoderToChan(in SimpleDecoder, c interface{}) error {
+	if c == nil {
+		return fmt.Errorf("gocsv: channel is %v", c)
+	}
+	return readEach(in, c)
+}
+
+// UnmarshalStringToChan parses the CSV from the string and send each value in the chan c.
+// The channel must have a concrete type.
+func UnmarshalStringToChan(in string, c interface{}) error { + return UnmarshalToChan(strings.NewReader(in), c) +} + +// UnmarshalBytesToChan parses the CSV from the bytes and send each value in the chan c. +// The channel must have a concrete type. +func UnmarshalBytesToChan(in []byte, c interface{}) error { + return UnmarshalToChan(bytes.NewReader(in), c) +} + +// UnmarshalToCallback parses the CSV from the reader and send each value to the given func f. +// The func must look like func(Struct). +func UnmarshalToCallback(in io.Reader, f interface{}) error { + valueFunc := reflect.ValueOf(f) + t := reflect.TypeOf(f) + if t.NumIn() != 1 { + return fmt.Errorf("the given function must have exactly one parameter") + } + cerr := make(chan error) + c := reflect.MakeChan(reflect.ChanOf(reflect.BothDir, t.In(0)), 0) + go func() { + cerr <- UnmarshalToChan(in, c.Interface()) + }() + for { + select { + case err := <-cerr: + return err + default: + } + v, notClosed := c.Recv() + if !notClosed || v.Interface() == nil { + break + } + valueFunc.Call([]reflect.Value{v}) + } + return nil +} + +// UnmarshalDecoderToCallback parses the CSV from the decoder and send each value to the given func f. +// The func must look like func(Struct). +func UnmarshalDecoderToCallback(in SimpleDecoder, f interface{}) error { + valueFunc := reflect.ValueOf(f) + t := reflect.TypeOf(f) + if t.NumIn() != 1 { + return fmt.Errorf("the given function must have exactly one parameter") + } + cerr := make(chan error) + c := reflect.MakeChan(reflect.ChanOf(reflect.BothDir, t.In(0)), 0) + go func() { + cerr <- UnmarshalDecoderToChan(in, c.Interface()) + }() + for { + select { + case err := <-cerr: + return err + default: + } + v, notClosed := c.Recv() + if !notClosed || v.Interface() == nil { + break + } + valueFunc.Call([]reflect.Value{v}) + } + return nil +} + +// UnmarshalBytesToCallback parses the CSV from the bytes and send each value to the given func f. +// The func must look like func(Struct). +func UnmarshalBytesToCallback(in []byte, f interface{}) error { + return UnmarshalToCallback(bytes.NewReader(in), f) +} + +// UnmarshalStringToCallback parses the CSV from the string and send each value to the given func f. +// The func must look like func(Struct). +func UnmarshalStringToCallback(in string, c interface{}) (err error) { + return UnmarshalToCallback(strings.NewReader(in), c) +} + +// UnmarshalToCallbackWithError parses the CSV from the reader and +// send each value to the given func f. +// +// If func returns error, it will stop processing, drain the +// parser and propagate the error to caller. +// +// The func must look like func(Struct) error. 
+func UnmarshalToCallbackWithError(in io.Reader, f interface{}) error { + valueFunc := reflect.ValueOf(f) + t := reflect.TypeOf(f) + if t.NumIn() != 1 { + return fmt.Errorf("the given function must have exactly one parameter") + } + if t.NumOut() != 1 { + return fmt.Errorf("the given function must have exactly one return value") + } + if !isErrorType(t.Out(0)) { + return fmt.Errorf("the given function must only return error.") + } + + cerr := make(chan error) + c := reflect.MakeChan(reflect.ChanOf(reflect.BothDir, t.In(0)), 0) + go func() { + cerr <- UnmarshalToChan(in, c.Interface()) + }() + + var fErr error + for { + select { + case err := <-cerr: + if err != nil { + return err + } + return fErr + default: + } + v, notClosed := c.Recv() + if !notClosed || v.Interface() == nil { + if err := <- cerr; err != nil { + fErr = err + } + break + } + + // callback f has already returned an error, stop processing but keep draining the chan c + if fErr != nil { + continue + } + + results := valueFunc.Call([]reflect.Value{v}) + + // If the callback f returns an error, stores it and returns it in future. + errValue := results[0].Interface() + if errValue != nil { + fErr = errValue.(error) + } + } + return fErr +} + +// UnmarshalBytesToCallbackWithError parses the CSV from the bytes and +// send each value to the given func f. +// +// If func returns error, it will stop processing, drain the +// parser and propagate the error to caller. +// +// The func must look like func(Struct) error. +func UnmarshalBytesToCallbackWithError(in []byte, f interface{}) error { + return UnmarshalToCallbackWithError(bytes.NewReader(in), f) +} + +// UnmarshalStringToCallbackWithError parses the CSV from the string and +// send each value to the given func f. +// +// If func returns error, it will stop processing, drain the +// parser and propagate the error to caller. +// +// The func must look like func(Struct) error. +func UnmarshalStringToCallbackWithError(in string, c interface{}) (err error) { + return UnmarshalToCallbackWithError(strings.NewReader(in), c) +} + +// CSVToMap creates a simple map from a CSV of 2 columns. +func CSVToMap(in io.Reader) (map[string]string, error) { + decoder := newSimpleDecoderFromReader(in) + header, err := decoder.getCSVRow() + if err != nil { + return nil, err + } + if len(header) != 2 { + return nil, fmt.Errorf("maps can only be created for csv of two columns") + } + m := make(map[string]string) + for { + line, err := decoder.getCSVRow() + if err == io.EOF { + break + } else if err != nil { + return nil, err + } + m[line[0]] = line[1] + } + return m, nil +} + +// CSVToMaps takes a reader and returns an array of dictionaries, using the header row as the keys +func CSVToMaps(reader io.Reader) ([]map[string]string, error) { + r := csv.NewReader(reader) + rows := []map[string]string{} + var header []string + for { + record, err := r.Read() + if err == io.EOF { + break + } + if err != nil { + return nil, err + } + if header == nil { + header = record + } else { + dict := map[string]string{} + for i := range header { + dict[header[i]] = record[i] + } + rows = append(rows, dict) + } + } + return rows, nil +} + +// CSVToChanMaps parses the CSV from the reader and send a dictionary in the chan c, using the header row as the keys. 
+func CSVToChanMaps(reader io.Reader, c chan<- map[string]string) error { + r := csv.NewReader(reader) + var header []string + for { + record, err := r.Read() + if err == io.EOF { + break + } + if err != nil { + return err + } + if header == nil { + header = record + } else { + dict := map[string]string{} + for i := range header { + dict[header[i]] = record[i] + } + c <- dict + } + } + return nil +} diff --git a/src/vendor/github.com/gocarina/gocsv/decode.go b/src/vendor/github.com/gocarina/gocsv/decode.go new file mode 100644 index 000000000..ebaa97cde --- /dev/null +++ b/src/vendor/github.com/gocarina/gocsv/decode.go @@ -0,0 +1,463 @@ +package gocsv + +import ( + "encoding/csv" + "errors" + "fmt" + "io" + "reflect" +) + +// Decoder . +type Decoder interface { + getCSVRows() ([][]string, error) +} + +// SimpleDecoder . +type SimpleDecoder interface { + getCSVRow() ([]string, error) + getCSVRows() ([][]string, error) +} + +type CSVReader interface { + Read() ([]string, error) + ReadAll() ([][]string, error) +} + +type csvDecoder struct { + CSVReader +} + +func newSimpleDecoderFromReader(r io.Reader) SimpleDecoder { + return csvDecoder{getCSVReader(r)} +} + +var ( + ErrEmptyCSVFile = errors.New("empty csv file given") + ErrNoStructTags = errors.New("no csv struct tags found") +) + +// NewSimpleDecoderFromCSVReader creates a SimpleDecoder, which may be passed +// to the UnmarshalDecoder* family of functions, from a CSV reader. Note that +// encoding/csv.Reader implements CSVReader, so you can pass one of those +// directly here. +func NewSimpleDecoderFromCSVReader(r CSVReader) SimpleDecoder { + return csvDecoder{r} +} + +func (c csvDecoder) getCSVRows() ([][]string, error) { + return c.ReadAll() +} + +func (c csvDecoder) getCSVRow() ([]string, error) { + return c.Read() +} + +func mismatchStructFields(structInfo []fieldInfo, headers []string) []string { + missing := make([]string, 0) + if len(structInfo) == 0 { + return missing + } + + headerMap := make(map[string]struct{}, len(headers)) + for idx := range headers { + headerMap[headers[idx]] = struct{}{} + } + + for _, info := range structInfo { + found := false + for _, key := range info.keys { + if _, ok := headerMap[key]; ok { + found = true + break + } + } + if !found { + missing = append(missing, info.keys...) 
+ } + } + return missing +} + +func mismatchHeaderFields(structInfo []fieldInfo, headers []string) []string { + missing := make([]string, 0) + if len(headers) == 0 { + return missing + } + + keyMap := make(map[string]struct{}) + for _, info := range structInfo { + for _, key := range info.keys { + keyMap[key] = struct{}{} + } + } + + for _, header := range headers { + if _, ok := keyMap[header]; !ok { + missing = append(missing, header) + } + } + return missing +} + +func maybeMissingStructFields(structInfo []fieldInfo, headers []string) error { + missing := mismatchStructFields(structInfo, headers) + if len(missing) != 0 { + return fmt.Errorf("found unmatched struct field with tags %v", missing) + } + return nil +} + +// Check that no header name is repeated twice +func maybeDoubleHeaderNames(headers []string) error { + headerMap := make(map[string]bool, len(headers)) + for _, v := range headers { + if _, ok := headerMap[v]; ok { + return fmt.Errorf("repeated header name: %v", v) + } + headerMap[v] = true + } + return nil +} + +// apply normalizer func to headers +func normalizeHeaders(headers []string) []string { + out := make([]string, len(headers)) + for i, h := range headers { + out[i] = normalizeName(h) + } + return out +} + +func readTo(decoder Decoder, out interface{}) error { + return readToWithErrorHandler(decoder, nil, out) +} + +func readToWithErrorHandler(decoder Decoder, errHandler ErrorHandler, out interface{}) error { + outValue, outType := getConcreteReflectValueAndType(out) // Get the concrete type (not pointer) (Slice or Array) + if err := ensureOutType(outType); err != nil { + return err + } + outInnerWasPointer, outInnerType := getConcreteContainerInnerType(outType) // Get the concrete inner type (not pointer) (Container<"?">) + if err := ensureOutInnerType(outInnerType); err != nil { + return err + } + csvRows, err := decoder.getCSVRows() // Get the CSV csvRows + if err != nil { + return err + } + if len(csvRows) == 0 { + return ErrEmptyCSVFile + } + if err := ensureOutCapacity(&outValue, len(csvRows)); err != nil { // Ensure the container is big enough to hold the CSV content + return err + } + outInnerStructInfo := getStructInfo(outInnerType) // Get the inner struct info to get CSV annotations + if len(outInnerStructInfo.Fields) == 0 { + return ErrNoStructTags + } + + headers := normalizeHeaders(csvRows[0]) + body := csvRows[1:] + + csvHeadersLabels := make(map[int]*fieldInfo, len(outInnerStructInfo.Fields)) // Used to store the correspondance header <-> position in CSV + + headerCount := map[string]int{} + for i, csvColumnHeader := range headers { + curHeaderCount := headerCount[csvColumnHeader] + if fieldInfo := getCSVFieldPosition(csvColumnHeader, outInnerStructInfo, curHeaderCount); fieldInfo != nil { + csvHeadersLabels[i] = fieldInfo + if ShouldAlignDuplicateHeadersWithStructFieldOrder { + curHeaderCount++ + headerCount[csvColumnHeader] = curHeaderCount + } + } + } + + if FailIfUnmatchedStructTags { + if err := maybeMissingStructFields(outInnerStructInfo.Fields, headers); err != nil { + return err + } + } + if FailIfDoubleHeaderNames { + if err := maybeDoubleHeaderNames(headers); err != nil { + return err + } + } + + var withFieldsOK bool + var fieldTypeUnmarshallerWithKeys TypeUnmarshalCSVWithFields + + for i, csvRow := range body { + objectIface := reflect.New(outValue.Index(i).Type()).Interface() + outInner := createNewOutInner(outInnerWasPointer, outInnerType) + for j, csvColumnContent := range csvRow { + if fieldInfo, ok := csvHeadersLabels[j]; ok { // 
Position found accordingly to header name + + if outInner.CanInterface() { + fieldTypeUnmarshallerWithKeys, withFieldsOK = objectIface.(TypeUnmarshalCSVWithFields) + if withFieldsOK { + if err := fieldTypeUnmarshallerWithKeys.UnmarshalCSVWithFields(fieldInfo.getFirstKey(), csvColumnContent); err != nil { + parseError := csv.ParseError{ + Line: i + 2, //add 2 to account for the header & 0-indexing of arrays + Column: j + 1, + Err: err, + } + return &parseError + } + continue + } + } + value := csvColumnContent + if value == "" { + value = fieldInfo.defaultValue + } + if err := setInnerField(&outInner, outInnerWasPointer, fieldInfo.IndexChain, value, fieldInfo.omitEmpty); err != nil { // Set field of struct + parseError := csv.ParseError{ + Line: i + 2, //add 2 to account for the header & 0-indexing of arrays + Column: j + 1, + Err: err, + } + if errHandler == nil || !errHandler(&parseError) { + return &parseError + } + } + } + } + + if withFieldsOK { + reflectedObject := reflect.ValueOf(objectIface) + outInner = reflectedObject.Elem() + } + + outValue.Index(i).Set(outInner) + } + return nil +} + +func readEach(decoder SimpleDecoder, c interface{}) error { + outValue, outType := getConcreteReflectValueAndType(c) // Get the concrete type (not pointer) + if outType.Kind() != reflect.Chan { + return fmt.Errorf("cannot use %v with type %s, only channel supported", c, outType) + } + defer outValue.Close() + + headers, err := decoder.getCSVRow() + if err != nil { + return err + } + headers = normalizeHeaders(headers) + + outInnerWasPointer, outInnerType := getConcreteContainerInnerType(outType) // Get the concrete inner type (not pointer) (Container<"?">) + if err := ensureOutInnerType(outInnerType); err != nil { + return err + } + outInnerStructInfo := getStructInfo(outInnerType) // Get the inner struct info to get CSV annotations + if len(outInnerStructInfo.Fields) == 0 { + return ErrNoStructTags + } + csvHeadersLabels := make(map[int]*fieldInfo, len(outInnerStructInfo.Fields)) // Used to store the correspondance header <-> position in CSV + headerCount := map[string]int{} + for i, csvColumnHeader := range headers { + curHeaderCount := headerCount[csvColumnHeader] + if fieldInfo := getCSVFieldPosition(csvColumnHeader, outInnerStructInfo, curHeaderCount); fieldInfo != nil { + csvHeadersLabels[i] = fieldInfo + if ShouldAlignDuplicateHeadersWithStructFieldOrder { + curHeaderCount++ + headerCount[csvColumnHeader] = curHeaderCount + } + } + } + if err := maybeMissingStructFields(outInnerStructInfo.Fields, headers); err != nil { + if FailIfUnmatchedStructTags { + return err + } + } + if FailIfDoubleHeaderNames { + if err := maybeDoubleHeaderNames(headers); err != nil { + return err + } + } + i := 0 + for { + line, err := decoder.getCSVRow() + if err == io.EOF { + break + } else if err != nil { + return err + } + outInner := createNewOutInner(outInnerWasPointer, outInnerType) + for j, csvColumnContent := range line { + if fieldInfo, ok := csvHeadersLabels[j]; ok { // Position found accordingly to header name + if err := setInnerField(&outInner, outInnerWasPointer, fieldInfo.IndexChain, csvColumnContent, fieldInfo.omitEmpty); err != nil { // Set field of struct + return &csv.ParseError{ + Line: i + 2, //add 2 to account for the header & 0-indexing of arrays + Column: j + 1, + Err: err, + } + } + } + } + outValue.Send(outInner) + i++ + } + return nil +} + +func readEachWithoutHeaders(decoder SimpleDecoder, c interface{}) error { + outValue, outType := getConcreteReflectValueAndType(c) // Get the concrete 
type (not pointer) (Slice or Array) + if err := ensureOutType(outType); err != nil { + return err + } + defer outValue.Close() + + outInnerWasPointer, outInnerType := getConcreteContainerInnerType(outType) // Get the concrete inner type (not pointer) (Container<"?">) + if err := ensureOutInnerType(outInnerType); err != nil { + return err + } + outInnerStructInfo := getStructInfo(outInnerType) // Get the inner struct info to get CSV annotations + if len(outInnerStructInfo.Fields) == 0 { + return ErrNoStructTags + } + + i := 0 + for { + line, err := decoder.getCSVRow() + if err == io.EOF { + break + } else if err != nil { + return err + } + outInner := createNewOutInner(outInnerWasPointer, outInnerType) + for j, csvColumnContent := range line { + fieldInfo := outInnerStructInfo.Fields[j] + if err := setInnerField(&outInner, outInnerWasPointer, fieldInfo.IndexChain, csvColumnContent, fieldInfo.omitEmpty); err != nil { // Set field of struct + return &csv.ParseError{ + Line: i + 2, //add 2 to account for the header & 0-indexing of arrays + Column: j + 1, + Err: err, + } + } + } + outValue.Send(outInner) + i++ + } + return nil +} + +func readToWithoutHeaders(decoder Decoder, out interface{}) error { + outValue, outType := getConcreteReflectValueAndType(out) // Get the concrete type (not pointer) (Slice or Array) + if err := ensureOutType(outType); err != nil { + return err + } + outInnerWasPointer, outInnerType := getConcreteContainerInnerType(outType) // Get the concrete inner type (not pointer) (Container<"?">) + if err := ensureOutInnerType(outInnerType); err != nil { + return err + } + csvRows, err := decoder.getCSVRows() // Get the CSV csvRows + if err != nil { + return err + } + if len(csvRows) == 0 { + return ErrEmptyCSVFile + } + if err := ensureOutCapacity(&outValue, len(csvRows)+1); err != nil { // Ensure the container is big enough to hold the CSV content + return err + } + outInnerStructInfo := getStructInfo(outInnerType) // Get the inner struct info to get CSV annotations + if len(outInnerStructInfo.Fields) == 0 { + return ErrNoStructTags + } + + for i, csvRow := range csvRows { + outInner := createNewOutInner(outInnerWasPointer, outInnerType) + for j, csvColumnContent := range csvRow { + fieldInfo := outInnerStructInfo.Fields[j] + if err := setInnerField(&outInner, outInnerWasPointer, fieldInfo.IndexChain, csvColumnContent, fieldInfo.omitEmpty); err != nil { // Set field of struct + return &csv.ParseError{ + Line: i + 1, + Column: j + 1, + Err: err, + } + } + } + outValue.Index(i).Set(outInner) + } + + return nil +} + +// Check if the outType is an array or a slice +func ensureOutType(outType reflect.Type) error { + switch outType.Kind() { + case reflect.Slice: + fallthrough + case reflect.Chan: + fallthrough + case reflect.Array: + return nil + } + return fmt.Errorf("cannot use " + outType.String() + ", only slice or array supported") +} + +// Check if the outInnerType is of type struct +func ensureOutInnerType(outInnerType reflect.Type) error { + switch outInnerType.Kind() { + case reflect.Struct: + return nil + } + return fmt.Errorf("cannot use " + outInnerType.String() + ", only struct supported") +} + +func ensureOutCapacity(out *reflect.Value, csvLen int) error { + switch out.Kind() { + case reflect.Array: + if out.Len() < csvLen-1 { // Array is not big enough to hold the CSV content (arrays are not addressable) + return fmt.Errorf("array capacity problem: cannot store %d %s in %s", csvLen-1, out.Type().Elem().String(), out.Type().String()) + } + case reflect.Slice: + if 
!out.CanAddr() && out.Len() < csvLen-1 { // Slice is not big enough to hold the CSV content and is not addressable
+			return fmt.Errorf("slice capacity problem and is not addressable (did you forget &?)")
+		} else if out.CanAddr() && out.Len() < csvLen-1 {
+			out.Set(reflect.MakeSlice(out.Type(), csvLen-1, csvLen-1)) // Slice is not big enough, so grow it
+		}
+	}
+	return nil
+}
+
+func getCSVFieldPosition(key string, structInfo *structInfo, curHeaderCount int) *fieldInfo {
+	matchedFieldCount := 0
+	for _, field := range structInfo.Fields {
+		if field.matchesKey(key) {
+			if matchedFieldCount >= curHeaderCount {
+				return &field
+			}
+			matchedFieldCount++
+		}
+	}
+	return nil
+}
+
+func createNewOutInner(outInnerWasPointer bool, outInnerType reflect.Type) reflect.Value {
+	if outInnerWasPointer {
+		return reflect.New(outInnerType)
+	}
+	return reflect.New(outInnerType).Elem()
+}
+
+func setInnerField(outInner *reflect.Value, outInnerWasPointer bool, index []int, value string, omitEmpty bool) error {
+	oi := *outInner
+	if outInnerWasPointer {
+		// initialize nil pointer
+		if oi.IsNil() {
+			setField(oi, "", omitEmpty)
+		}
+		oi = outInner.Elem()
+	}
+	// because pointers can be nil need to recurse one index at a time and perform nil check
+	if len(index) > 1 {
+		nextField := oi.Field(index[0])
+		return setInnerField(&nextField, nextField.Kind() == reflect.Ptr, index[1:], value, omitEmpty)
+	}
+	return setField(oi.FieldByIndex(index), value, omitEmpty)
+}
diff --git a/src/vendor/github.com/gocarina/gocsv/encode.go b/src/vendor/github.com/gocarina/gocsv/encode.go
new file mode 100644
index 000000000..896df0e58
--- /dev/null
+++ b/src/vendor/github.com/gocarina/gocsv/encode.go
@@ -0,0 +1,147 @@
+package gocsv
+
+import (
+	"fmt"
+	"io"
+	"reflect"
+)
+
+type encoder struct {
+	out io.Writer
+}
+
+func newEncoder(out io.Writer) *encoder {
+	return &encoder{out}
+}
+
+func writeFromChan(writer CSVWriter, c <-chan interface{}) error {
+	// Get the first value. It will determine the header structure.
+ firstValue, ok := <-c + if !ok { + return fmt.Errorf("channel is closed") + } + inValue, inType := getConcreteReflectValueAndType(firstValue) // Get the concrete type + if err := ensureStructOrPtr(inType); err != nil { + return err + } + inInnerWasPointer := inType.Kind() == reflect.Ptr + inInnerStructInfo := getStructInfo(inType) // Get the inner struct info to get CSV annotations + csvHeadersLabels := make([]string, len(inInnerStructInfo.Fields)) + for i, fieldInfo := range inInnerStructInfo.Fields { // Used to write the header (first line) in CSV + csvHeadersLabels[i] = fieldInfo.getFirstKey() + } + if err := writer.Write(csvHeadersLabels); err != nil { + return err + } + write := func(val reflect.Value) error { + for j, fieldInfo := range inInnerStructInfo.Fields { + csvHeadersLabels[j] = "" + inInnerFieldValue, err := getInnerField(val, inInnerWasPointer, fieldInfo.IndexChain) // Get the correct field header <-> position + if err != nil { + return err + } + csvHeadersLabels[j] = inInnerFieldValue + } + if err := writer.Write(csvHeadersLabels); err != nil { + return err + } + return nil + } + if err := write(inValue); err != nil { + return err + } + for v := range c { + val, _ := getConcreteReflectValueAndType(v) // Get the concrete type (not pointer) (Slice or Array) + if err := ensureStructOrPtr(inType); err != nil { + return err + } + if err := write(val); err != nil { + return err + } + } + writer.Flush() + return writer.Error() +} + +func writeTo(writer CSVWriter, in interface{}, omitHeaders bool) error { + inValue, inType := getConcreteReflectValueAndType(in) // Get the concrete type (not pointer) (Slice or Array) + if err := ensureInType(inType); err != nil { + return err + } + inInnerWasPointer, inInnerType := getConcreteContainerInnerType(inType) // Get the concrete inner type (not pointer) (Container<"?">) + if err := ensureInInnerType(inInnerType); err != nil { + return err + } + inInnerStructInfo := getStructInfo(inInnerType) // Get the inner struct info to get CSV annotations + csvHeadersLabels := make([]string, len(inInnerStructInfo.Fields)) + for i, fieldInfo := range inInnerStructInfo.Fields { // Used to write the header (first line) in CSV + csvHeadersLabels[i] = fieldInfo.getFirstKey() + } + if !omitHeaders { + if err := writer.Write(csvHeadersLabels); err != nil { + return err + } + } + inLen := inValue.Len() + for i := 0; i < inLen; i++ { // Iterate over container rows + for j, fieldInfo := range inInnerStructInfo.Fields { + csvHeadersLabels[j] = "" + inInnerFieldValue, err := getInnerField(inValue.Index(i), inInnerWasPointer, fieldInfo.IndexChain) // Get the correct field header <-> position + if err != nil { + return err + } + csvHeadersLabels[j] = inInnerFieldValue + } + if err := writer.Write(csvHeadersLabels); err != nil { + return err + } + } + writer.Flush() + return writer.Error() +} + +func ensureStructOrPtr(t reflect.Type) error { + switch t.Kind() { + case reflect.Struct: + fallthrough + case reflect.Ptr: + return nil + } + return fmt.Errorf("cannot use " + t.String() + ", only slice or array supported") +} + +// Check if the inType is an array or a slice +func ensureInType(outType reflect.Type) error { + switch outType.Kind() { + case reflect.Slice: + fallthrough + case reflect.Array: + return nil + } + return fmt.Errorf("cannot use " + outType.String() + ", only slice or array supported") +} + +// Check if the inInnerType is of type struct +func ensureInInnerType(outInnerType reflect.Type) error { + switch outInnerType.Kind() { + case reflect.Struct: 
+ return nil + } + return fmt.Errorf("cannot use " + outInnerType.String() + ", only struct supported") +} + +func getInnerField(outInner reflect.Value, outInnerWasPointer bool, index []int) (string, error) { + oi := outInner + if outInnerWasPointer { + if oi.IsNil() { + return "", nil + } + oi = outInner.Elem() + } + // because pointers can be nil need to recurse one index at a time and perform nil check + if len(index) > 1 { + nextField := oi.Field(index[0]) + return getInnerField(nextField, nextField.Kind() == reflect.Ptr, index[1:]) + } + return getFieldAsString(oi.FieldByIndex(index)) +} diff --git a/src/vendor/github.com/gocarina/gocsv/go.mod b/src/vendor/github.com/gocarina/gocsv/go.mod new file mode 100644 index 000000000..c746a5a05 --- /dev/null +++ b/src/vendor/github.com/gocarina/gocsv/go.mod @@ -0,0 +1,3 @@ +module github.com/gocarina/gocsv + +go 1.13 diff --git a/src/vendor/github.com/gocarina/gocsv/reflect.go b/src/vendor/github.com/gocarina/gocsv/reflect.go new file mode 100644 index 000000000..9ab5c4a37 --- /dev/null +++ b/src/vendor/github.com/gocarina/gocsv/reflect.go @@ -0,0 +1,143 @@ +package gocsv + +import ( + "reflect" + "strings" + "sync" +) + +// -------------------------------------------------------------------------- +// Reflection helpers + +type structInfo struct { + Fields []fieldInfo +} + +// fieldInfo is a struct field that should be mapped to a CSV column, or vice-versa +// Each IndexChain element before the last is the index of an the embedded struct field +// that defines Key as a tag +type fieldInfo struct { + keys []string + omitEmpty bool + IndexChain []int + defaultValue string +} + +func (f fieldInfo) getFirstKey() string { + return f.keys[0] +} + +func (f fieldInfo) matchesKey(key string) bool { + for _, k := range f.keys { + if key == k || strings.TrimSpace(key) == k { + return true + } + } + return false +} + +var structInfoCache sync.Map +var structMap = make(map[reflect.Type]*structInfo) +var structMapMutex sync.RWMutex + +func getStructInfo(rType reflect.Type) *structInfo { + stInfo, ok := structInfoCache.Load(rType) + if ok { + return stInfo.(*structInfo) + } + + fieldsList := getFieldInfos(rType, []int{}) + stInfo = &structInfo{fieldsList} + structInfoCache.Store(rType, stInfo) + + return stInfo.(*structInfo) +} + +func getFieldInfos(rType reflect.Type, parentIndexChain []int) []fieldInfo { + fieldsCount := rType.NumField() + fieldsList := make([]fieldInfo, 0, fieldsCount) + for i := 0; i < fieldsCount; i++ { + field := rType.Field(i) + if field.PkgPath != "" { + continue + } + + var cpy = make([]int, len(parentIndexChain)) + copy(cpy, parentIndexChain) + indexChain := append(cpy, i) + + // if the field is a pointer to a struct, follow the pointer then create fieldinfo for each field + if field.Type.Kind() == reflect.Ptr && field.Type.Elem().Kind() == reflect.Struct { + // unless it implements marshalText or marshalCSV. Structs that implement this + // should result in one value and not have their fields exposed + if !(canMarshal(field.Type.Elem())) { + fieldsList = append(fieldsList, getFieldInfos(field.Type.Elem(), indexChain)...) + } + } + // if the field is a struct, create a fieldInfo for each of its fields + if field.Type.Kind() == reflect.Struct { + // unless it implements marshalText or marshalCSV. Structs that implement this + // should result in one value and not have their fields exposed + if !(canMarshal(field.Type)) { + fieldsList = append(fieldsList, getFieldInfos(field.Type, indexChain)...) 
+			}
+		}
+
+		// if the field is an embedded struct, ignore the csv tag
+		if field.Anonymous {
+			continue
+		}
+
+		fieldInfo := fieldInfo{IndexChain: indexChain}
+		fieldTag := field.Tag.Get(TagName)
+		fieldTags := strings.Split(fieldTag, TagSeparator)
+		filteredTags := []string{}
+		for _, fieldTagEntry := range fieldTags {
+			if fieldTagEntry == "omitempty" {
+				fieldInfo.omitEmpty = true
+			} else if strings.HasPrefix(fieldTagEntry, "default=") {
+				fieldInfo.defaultValue = strings.TrimPrefix(fieldTagEntry, "default=")
+			} else {
+				filteredTags = append(filteredTags, normalizeName(fieldTagEntry))
+			}
+		}
+
+		if len(filteredTags) == 1 && filteredTags[0] == "-" {
+			continue
+		} else if len(filteredTags) > 0 && filteredTags[0] != "" {
+			fieldInfo.keys = filteredTags
+		} else {
+			fieldInfo.keys = []string{normalizeName(field.Name)}
+		}
+		fieldsList = append(fieldsList, fieldInfo)
+	}
+	return fieldsList
+}
+
+func getConcreteContainerInnerType(in reflect.Type) (inInnerWasPointer bool, inInnerType reflect.Type) {
+	inInnerType = in.Elem()
+	inInnerWasPointer = false
+	if inInnerType.Kind() == reflect.Ptr {
+		inInnerWasPointer = true
+		inInnerType = inInnerType.Elem()
+	}
+	return inInnerWasPointer, inInnerType
+}
+
+func getConcreteReflectValueAndType(in interface{}) (reflect.Value, reflect.Type) {
+	value := reflect.ValueOf(in)
+	if value.Kind() == reflect.Ptr {
+		value = value.Elem()
+	}
+	return value, value.Type()
+}
+
+var errorInterface = reflect.TypeOf((*error)(nil)).Elem()
+
+func isErrorType(outType reflect.Type) bool {
+	if outType.Kind() != reflect.Interface {
+		return false
+	}
+
+	return outType.Implements(errorInterface)
+}
diff --git a/src/vendor/github.com/gocarina/gocsv/safe_csv.go b/src/vendor/github.com/gocarina/gocsv/safe_csv.go
new file mode 100644
index 000000000..858b07816
--- /dev/null
+++ b/src/vendor/github.com/gocarina/gocsv/safe_csv.go
@@ -0,0 +1,38 @@
+package gocsv
+
+// SafeCSVWriter wraps csv.Writer and makes it thread safe.
+import (
+	"encoding/csv"
+	"sync"
+)
+
+type CSVWriter interface {
+	Write(row []string) error
+	Flush()
+	Error() error
+}
+
+type SafeCSVWriter struct {
+	*csv.Writer
+	m sync.Mutex
+}
+
+func NewSafeCSVWriter(original *csv.Writer) *SafeCSVWriter {
+	return &SafeCSVWriter{
+		Writer: original,
+	}
+}
+
+// Override Write: serialize writes with the mutex
+func (w *SafeCSVWriter) Write(row []string) error {
+	w.m.Lock()
+	defer w.m.Unlock()
+	return w.Writer.Write(row)
+}
+
+// Override Flush: serialize flushes with the mutex
+func (w *SafeCSVWriter) Flush() {
+	w.m.Lock()
+	w.Writer.Flush()
+	w.m.Unlock()
+}
diff --git a/src/vendor/github.com/gocarina/gocsv/types.go b/src/vendor/github.com/gocarina/gocsv/types.go
new file mode 100644
index 000000000..5c32d36b6
--- /dev/null
+++ b/src/vendor/github.com/gocarina/gocsv/types.go
@@ -0,0 +1,456 @@
+package gocsv
+
+import (
+	"encoding"
+	"fmt"
+	"reflect"
+	"strconv"
+	"strings"
+
+	"encoding/json"
+)
+
+// --------------------------------------------------------------------------
+// Conversion interfaces
+
+// TypeMarshaller is implemented by any value that has a MarshalCSV method
+// This converter is used to convert the value to its string representation
+type TypeMarshaller interface {
+	MarshalCSV() (string, error)
+}
+
+// TypeUnmarshaller is implemented by any value that has an UnmarshalCSV method
+// This converter is used to convert a string to your value representation of that string
+type TypeUnmarshaller interface {
+	UnmarshalCSV(string) error
+}
+
+// TypeUnmarshalCSVWithFields can be implemented on whole structs to customize how their individual fields are unmarshalled
+type TypeUnmarshalCSVWithFields interface {
+	UnmarshalCSVWithFields(key, value string) error
+}
+
+// NoUnmarshalFuncError is the custom error type to be raised in case there is no unmarshal function defined on the type
+type NoUnmarshalFuncError struct {
+	msg string
+}
+
+func (e NoUnmarshalFuncError) Error() string {
+	return e.msg
+}
+
+// NoMarshalFuncError is the custom error type to be raised in case there is no marshal function defined on the type
+type NoMarshalFuncError struct {
+	ty reflect.Type
+}
+
+func (e NoMarshalFuncError) Error() string {
+	return "No known conversion from " + e.ty.String() + " to string, " + e.ty.String() + " does not implement TypeMarshaller nor Stringer"
+}
+
+// --------------------------------------------------------------------------
+// Conversion helpers
+
+func toString(in interface{}) (string, error) {
+	inValue := reflect.ValueOf(in)
+
+	switch inValue.Kind() {
+	case reflect.String:
+		return inValue.String(), nil
+	case reflect.Bool:
+		b := inValue.Bool()
+		if b {
+			return "true", nil
+		}
+		return "false", nil
+	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
+		return fmt.Sprintf("%v", inValue.Int()), nil
+	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
+		return fmt.Sprintf("%v", inValue.Uint()), nil
+	case reflect.Float32:
+		return strconv.FormatFloat(inValue.Float(), byte('f'), -1, 32), nil
+	case reflect.Float64:
+		return strconv.FormatFloat(inValue.Float(), byte('f'), -1, 64), nil
+	}
+	return "", fmt.Errorf("No known conversion from " + inValue.Type().String() + " to string")
+}
+
+func toBool(in interface{}) (bool, error) {
+	inValue := reflect.ValueOf(in)
+
+	switch inValue.Kind() {
+	case reflect.String:
+		s := inValue.String()
+		s = strings.TrimSpace(s)
+		if strings.EqualFold(s, "yes") {
+			return true, nil
+		} else if strings.EqualFold(s, "no") || s == "" {
+			return false, nil
+		} else {
+			return
strconv.ParseBool(s) + } + case reflect.Bool: + return inValue.Bool(), nil + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + i := inValue.Int() + if i != 0 { + return true, nil + } + return false, nil + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + i := inValue.Uint() + if i != 0 { + return true, nil + } + return false, nil + case reflect.Float32, reflect.Float64: + f := inValue.Float() + if f != 0 { + return true, nil + } + return false, nil + } + return false, fmt.Errorf("No known conversion from " + inValue.Type().String() + " to bool") +} + +func toInt(in interface{}) (int64, error) { + inValue := reflect.ValueOf(in) + + switch inValue.Kind() { + case reflect.String: + s := strings.TrimSpace(inValue.String()) + if s == "" { + return 0, nil + } + out := strings.SplitN(s, ".", 2) + return strconv.ParseInt(out[0], 0, 64) + case reflect.Bool: + if inValue.Bool() { + return 1, nil + } + return 0, nil + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return inValue.Int(), nil + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return int64(inValue.Uint()), nil + case reflect.Float32, reflect.Float64: + return int64(inValue.Float()), nil + } + return 0, fmt.Errorf("No known conversion from " + inValue.Type().String() + " to int") +} + +func toUint(in interface{}) (uint64, error) { + inValue := reflect.ValueOf(in) + + switch inValue.Kind() { + case reflect.String: + s := strings.TrimSpace(inValue.String()) + if s == "" { + return 0, nil + } + + // support the float input + if strings.Contains(s, ".") { + f, err := strconv.ParseFloat(s, 64) + if err != nil { + return 0, err + } + return uint64(f), nil + } + return strconv.ParseUint(s, 0, 64) + case reflect.Bool: + if inValue.Bool() { + return 1, nil + } + return 0, nil + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return uint64(inValue.Int()), nil + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return inValue.Uint(), nil + case reflect.Float32, reflect.Float64: + return uint64(inValue.Float()), nil + } + return 0, fmt.Errorf("No known conversion from " + inValue.Type().String() + " to uint") +} + +func toFloat(in interface{}) (float64, error) { + inValue := reflect.ValueOf(in) + + switch inValue.Kind() { + case reflect.String: + s := strings.TrimSpace(inValue.String()) + if s == "" { + return 0, nil + } + return strconv.ParseFloat(s, 64) + case reflect.Bool: + if inValue.Bool() { + return 1, nil + } + return 0, nil + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return float64(inValue.Int()), nil + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + return float64(inValue.Uint()), nil + case reflect.Float32, reflect.Float64: + return inValue.Float(), nil + } + return 0, fmt.Errorf("No known conversion from " + inValue.Type().String() + " to float") +} + +func setField(field reflect.Value, value string, omitEmpty bool) error { + if field.Kind() == reflect.Ptr { + if omitEmpty && value == "" { + return nil + } + if field.IsNil() { + field.Set(reflect.New(field.Type().Elem())) + } + field = field.Elem() + } + + switch field.Interface().(type) { + case string: + s, err := toString(value) + if err != nil { + return err + } + field.SetString(s) + case bool: + b, err := toBool(value) + if err != nil { + return err + } + field.SetBool(b) + case int, int8, int16, int32, int64: + i, err 
:= toInt(value) + if err != nil { + return err + } + field.SetInt(i) + case uint, uint8, uint16, uint32, uint64: + ui, err := toUint(value) + if err != nil { + return err + } + field.SetUint(ui) + case float32, float64: + f, err := toFloat(value) + if err != nil { + return err + } + field.SetFloat(f) + default: + // Not a native type, check for unmarshal method + if err := unmarshall(field, value); err != nil { + if _, ok := err.(NoUnmarshalFuncError); !ok { + return err + } + // Could not unmarshal, check for kind, e.g. renamed type from basic type + switch field.Kind() { + case reflect.String: + s, err := toString(value) + if err != nil { + return err + } + field.SetString(s) + case reflect.Bool: + b, err := toBool(value) + if err != nil { + return err + } + field.SetBool(b) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + i, err := toInt(value) + if err != nil { + return err + } + field.SetInt(i) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + ui, err := toUint(value) + if err != nil { + return err + } + field.SetUint(ui) + case reflect.Float32, reflect.Float64: + f, err := toFloat(value) + if err != nil { + return err + } + field.SetFloat(f) + case reflect.Slice, reflect.Struct: + err := json.Unmarshal([]byte(value), field.Addr().Interface()) + if err != nil { + return err + } + default: + return err + } + } else { + return nil + } + } + return nil +} + +func getFieldAsString(field reflect.Value) (str string, err error) { + switch field.Kind() { + case reflect.Interface, reflect.Ptr: + if field.IsNil() { + return "", nil + } + return getFieldAsString(field.Elem()) + default: + // Check if field is go native type + switch field.Interface().(type) { + case string: + return field.String(), nil + case bool: + if field.Bool() { + return "true", nil + } else { + return "false", nil + } + case int, int8, int16, int32, int64: + return fmt.Sprintf("%v", field.Int()), nil + case uint, uint8, uint16, uint32, uint64: + return fmt.Sprintf("%v", field.Uint()), nil + case float32: + str, err = toString(float32(field.Float())) + if err != nil { + return str, err + } + case float64: + str, err = toString(field.Float()) + if err != nil { + return str, err + } + default: + // Not a native type, check for marshal method + str, err = marshall(field) + if err != nil { + if _, ok := err.(NoMarshalFuncError); !ok { + return str, err + } + // If not marshal method, is field compatible with/renamed from native type + switch field.Kind() { + case reflect.String: + return field.String(), nil + case reflect.Bool: + str, err = toString(field.Bool()) + if err != nil { + return str, err + } + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + str, err = toString(field.Int()) + if err != nil { + return str, err + } + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: + str, err = toString(field.Uint()) + if err != nil { + return str, err + } + case reflect.Float32: + str, err = toString(float32(field.Float())) + if err != nil { + return str, err + } + case reflect.Float64: + str, err = toString(field.Float()) + if err != nil { + return str, err + } + } + } else { + return str, nil + } + } + } + return str, nil +} + +// -------------------------------------------------------------------------- +// Un/serializations helpers + +func canMarshal(t reflect.Type) bool { + // unless it implements marshalText or marshalCSV. 
Structs that implement this + // should result in one value and not have their fields exposed + _, canMarshalText := t.MethodByName("MarshalText") + _, canMarshalCSV := t.MethodByName("MarshalCSV") + return canMarshalCSV || canMarshalText +} + +func unmarshall(field reflect.Value, value string) error { + dupField := field + unMarshallIt := func(finalField reflect.Value) error { + if finalField.CanInterface() { + fieldIface := finalField.Interface() + + fieldTypeUnmarshaller, ok := fieldIface.(TypeUnmarshaller) + if ok { + return fieldTypeUnmarshaller.UnmarshalCSV(value) + } + + // Otherwise try to use TextUnmarshaler + fieldTextUnmarshaler, ok := fieldIface.(encoding.TextUnmarshaler) + if ok { + return fieldTextUnmarshaler.UnmarshalText([]byte(value)) + } + } + + return NoUnmarshalFuncError{"No known conversion from string to " + field.Type().String() + ", " + field.Type().String() + " does not implement TypeUnmarshaller"} + } + for dupField.Kind() == reflect.Interface || dupField.Kind() == reflect.Ptr { + if dupField.IsNil() { + dupField = reflect.New(field.Type().Elem()) + field.Set(dupField) + return unMarshallIt(dupField) + } + dupField = dupField.Elem() + } + if dupField.CanAddr() { + return unMarshallIt(dupField.Addr()) + } + return NoUnmarshalFuncError{"No known conversion from string to " + field.Type().String() + ", " + field.Type().String() + " does not implement TypeUnmarshaller"} +} + +func marshall(field reflect.Value) (value string, err error) { + dupField := field + marshallIt := func(finalField reflect.Value) (string, error) { + if finalField.CanInterface() { + fieldIface := finalField.Interface() + + // Use TypeMarshaller when possible + fieldTypeMarhaller, ok := fieldIface.(TypeMarshaller) + if ok { + return fieldTypeMarhaller.MarshalCSV() + } + + // Otherwise try to use TextMarshaller + fieldTextMarshaler, ok := fieldIface.(encoding.TextMarshaler) + if ok { + text, err := fieldTextMarshaler.MarshalText() + return string(text), err + } + + // Otherwise try to use Stringer + fieldStringer, ok := fieldIface.(fmt.Stringer) + if ok { + return fieldStringer.String(), nil + } + } + + return value, NoMarshalFuncError{field.Type()} + } + for dupField.Kind() == reflect.Interface || dupField.Kind() == reflect.Ptr { + if dupField.IsNil() { + return value, nil + } + dupField = dupField.Elem() + } + if dupField.CanAddr() { + dupField = dupField.Addr() + } + return marshallIt(dupField) +} diff --git a/src/vendor/github.com/gocarina/gocsv/unmarshaller.go b/src/vendor/github.com/gocarina/gocsv/unmarshaller.go new file mode 100644 index 000000000..8a31a7501 --- /dev/null +++ b/src/vendor/github.com/gocarina/gocsv/unmarshaller.go @@ -0,0 +1,117 @@ +package gocsv + +import ( + "encoding/csv" + "fmt" + "reflect" +) + +// Unmarshaller is a CSV to struct unmarshaller. +type Unmarshaller struct { + reader *csv.Reader + Headers []string + fieldInfoMap []*fieldInfo + MismatchedHeaders []string + MismatchedStructFields []string + outType reflect.Type +} + +// NewUnmarshaller creates an unmarshaller from a csv.Reader and a struct. 
+func NewUnmarshaller(reader *csv.Reader, out interface{}) (*Unmarshaller, error) {
+    headers, err := reader.Read()
+    if err != nil {
+        return nil, err
+    }
+    headers = normalizeHeaders(headers)
+
+    um := &Unmarshaller{reader: reader, outType: reflect.TypeOf(out)}
+    err = validate(um, out, headers)
+    if err != nil {
+        return nil, err
+    }
+    return um, nil
+}
+
+// Read returns an interface{} whose runtime type is the same as the struct that
+// was used to create the Unmarshaller.
+func (um *Unmarshaller) Read() (interface{}, error) {
+    row, err := um.reader.Read()
+    if err != nil {
+        return nil, err
+    }
+    return um.unmarshalRow(row, nil)
+}
+
+// ReadUnmatched is same as Read(), but returns a map of the columns that didn't match a field in the struct
+func (um *Unmarshaller) ReadUnmatched() (interface{}, map[string]string, error) {
+    row, err := um.reader.Read()
+    if err != nil {
+        return nil, nil, err
+    }
+    unmatched := make(map[string]string)
+    value, err := um.unmarshalRow(row, unmatched)
+    return value, unmatched, err
+}
+
+// validate ensures that a struct was used to create the Unmarshaller, and validates
+// CSV headers against the CSV tags in the struct.
+func validate(um *Unmarshaller, s interface{}, headers []string) error {
+    concreteType := reflect.TypeOf(s)
+    if concreteType.Kind() == reflect.Ptr {
+        concreteType = concreteType.Elem()
+    }
+    if err := ensureOutInnerType(concreteType); err != nil {
+        return err
+    }
+    structInfo := getStructInfo(concreteType) // Get struct info to get CSV annotations.
+    if len(structInfo.Fields) == 0 {
+        return ErrNoStructTags
+    }
+    csvHeadersLabels := make([]*fieldInfo, len(headers)) // Used to store the corresponding header <-> position in CSV
+    headerCount := map[string]int{}
+    for i, csvColumnHeader := range headers {
+        curHeaderCount := headerCount[csvColumnHeader]
+        if fieldInfo := getCSVFieldPosition(csvColumnHeader, structInfo, curHeaderCount); fieldInfo != nil {
+            csvHeadersLabels[i] = fieldInfo
+            if ShouldAlignDuplicateHeadersWithStructFieldOrder {
+                curHeaderCount++
+                headerCount[csvColumnHeader] = curHeaderCount
+            }
+        }
+    }
+
+    if FailIfDoubleHeaderNames {
+        if err := maybeDoubleHeaderNames(headers); err != nil {
+            return err
+        }
+    }
+
+    um.Headers = headers
+    um.fieldInfoMap = csvHeadersLabels
+    um.MismatchedHeaders = mismatchHeaderFields(structInfo.Fields, headers)
+    um.MismatchedStructFields = mismatchStructFields(structInfo.Fields, headers)
+    return nil
+}
+
+// unmarshalRow converts a CSV row to a struct, based on CSV struct tags.
+// If unmatched is non nil, it is populated with any columns that don't map to a struct field
+func (um *Unmarshaller) unmarshalRow(row []string, unmatched map[string]string) (interface{}, error) {
+    isPointer := false
+    concreteOutType := um.outType
+    if um.outType.Kind() == reflect.Ptr {
+        isPointer = true
+        concreteOutType = concreteOutType.Elem()
+    }
+    outValue := createNewOutInner(isPointer, concreteOutType)
+    for j, csvColumnContent := range row {
+        if j < len(um.fieldInfoMap) && um.fieldInfoMap[j] != nil {
+            fieldInfo := um.fieldInfoMap[j]
+            if err := setInnerField(&outValue, isPointer, fieldInfo.IndexChain, csvColumnContent, fieldInfo.omitEmpty); err != nil { // Set field of struct
+                return nil, fmt.Errorf("cannot assign field at %v to %s through index chain %v: %v", j, outValue.Type(), fieldInfo.IndexChain, err)
+            }
+        } else if unmatched != nil {
+            unmatched[um.Headers[j]] = csvColumnContent
+        }
+    }
+    return outValue.Interface(), nil
+}
diff --git a/src/vendor/modules.txt b/src/vendor/modules.txt
index 2779f644b..582bedadc 100644
--- a/src/vendor/modules.txt
+++ b/src/vendor/modules.txt
@@ -399,6 +399,9 @@ github.com/go-sql-driver/mysql
 # github.com/go-stack/stack v1.8.0
 ## explicit
 github.com/go-stack/stack
+# github.com/gocarina/gocsv v0.0.0-20210516172204-ca9e8a8ddea8
+## explicit
+github.com/gocarina/gocsv
 # github.com/gocraft/work v0.5.1
 ## explicit
 github.com/gocraft/work
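
Note for reviewers (not part of the patch): a minimal, self-contained sketch of how the vendored gocsv pieces above fit together. It shows the two extension points that marshall()/unmarshall() probe for, a custom type implementing MarshalCSV/UnmarshalCSV (gocsv's TypeMarshaller/TypeUnmarshaller interfaces), read row-by-row through the streaming Unmarshaller added in unmarshaller.go. The DateTime and Vulnerability types, field names, and date layout are illustrative assumptions, not Harbor code.

    package main

    import (
        "encoding/csv"
        "fmt"
        "io"
        "strings"
        "time"

        "github.com/gocarina/gocsv"
    )

    // DateTime is a hypothetical wrapper type; its methods satisfy the
    // TypeUnmarshaller/TypeMarshaller interfaces that unmarshall()/marshall()
    // above check before falling back to TextUnmarshaler/TextMarshaler.
    type DateTime struct{ time.Time }

    // UnmarshalCSV is invoked per cell via setField -> unmarshall.
    func (d *DateTime) UnmarshalCSV(s string) (err error) {
        d.Time, err = time.Parse("2006-01-02", s)
        return err
    }

    // MarshalCSV is invoked via getFieldAsString -> marshall when writing.
    func (d DateTime) MarshalCSV() (string, error) {
        return d.Time.Format("2006-01-02"), nil
    }

    // Vulnerability is an illustrative row type; validate() matches its csv
    // tags against the header row when the Unmarshaller is constructed.
    type Vulnerability struct {
        CVE      string   `csv:"CVE"`
        Severity string   `csv:"Severity"`
        Seen     DateTime `csv:"Seen"`
    }

    func main() {
        data := "CVE,Severity,Seen\nCVE-2022-0001,High,2022-07-11\n"

        um, err := gocsv.NewUnmarshaller(csv.NewReader(strings.NewReader(data)), Vulnerability{})
        if err != nil {
            panic(err)
        }
        for {
            // Read unmarshals one row at a time via unmarshalRow(), so a
            // large CSV export never has to be held in memory all at once.
            rec, err := um.Read()
            if err == io.EOF {
                break
            }
            if err != nil {
                panic(err)
            }
            fmt.Printf("%+v\n", rec)
        }
    }

For small payloads the one-shot helpers in csv.go (gocsv.MarshalString / gocsv.UnmarshalString) are simpler; the streaming Unmarshaller is the variant relevant to large CVE exports.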