Claran's blog

怎么能成为Go学长,不行不行!(※也不是不行?!)

Go语言优雅错误处理指南

概述

本文档详细讲解如何在Go语言项目中实现优雅的错误处理机制,特别是在Gin框架的Handler-Service-DAO分层架构中。

核心原则

1. 错误是值,不是异常

Go语言将错误视为普通返回值,这要求开发者显式处理每个可能的错误。

2. 添加上下文信息

错误在传递过程中应该携带足够的上下文信息,便于问题定位。

3. 统一错误响应

API应该返回统一格式的错误响应,方便客户端处理。

分层错误处理架构

项目结构

1
2
3
4
5
6
7
8
9
10
11
12
13
project/
├── response/
│ └── response.go
├── errors/
│ └── business.go
├── handler/
│ └── user_handler.go
├── service/
│ └── user_service.go
├── dao/
│ └── user_dao.go
└── model/
└── user.go

为什么需要分层错误处理?

各层职责和错误处理策略

层级 职责 错误处理策略 为什么这样设计
DAO层 数据访问,纯技术操作 返回原始错误或基础业务错误 DAO层不应该关心业务逻辑,只负责技术错误
Service层 业务逻辑处理 将技术错误转换为业务错误,添加业务上下文 Service层理解业务含义,知道如何包装错误
Handler层 HTTP请求处理 捕获所有错误,转换为HTTP响应 Handler层是系统边界,需要统一响应格式

错误传递的哲学

  1. DAO层保持纯粹:只处理数据访问相关错误,不添加业务语义
  2. Service层添加业务语义:将技术错误翻译成业务人员能理解的错误
  3. Handler层统一格式化:将错误转换为客户端能理解的格式

源码实现

1. 统一响应包 (response/response.go)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
package response

import (
"net/http"

"github.com/gin-gonic/gin"
)

type Response struct {
Code int `json:"code"`
Message string `json:"message"`
Data interface{} `json:"data,omitempty"`
Error string `json:"error,omitempty"`
}

func Success(c *gin.Context, data interface{}) {
c.JSON(http.StatusOK, Response{
Code: 0,
Message: "success",
Data: data,
})
}

func Error(c *gin.Context, code int, message string) {
c.JSON(http.StatusOK, Response{
Code: code,
Message: message,
})
}

func InternalError(c *gin.Context, err error) {
c.Error(err)
c.JSON(http.StatusOK, Response{
Code: 500,
Message: "内部服务器错误",
})
}

func BadRequest(c *gin.Context, message string) {
c.JSON(http.StatusOK, Response{
Code: 400,
Message: message,
})
}

func NotFound(c *gin.Context, message string) {
c.JSON(http.StatusOK, Response{
Code: 404,
Message: message,
})
}

设计理由:统一响应格式确保客户端始终收到结构一致的响应,便于错误处理和用户体验优化。

2. 业务错误定义 (errors/business.go)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
package errors

import "fmt"

const (
CodeUserNotFound = 10001
CodeInvalidParam = 10002
CodeDBError = 10003
)

type BusinessError struct {
Code int
Message string
Err error
}

func (e *BusinessError) Error() string {
if e.Err != nil {
return fmt.Sprintf("业务错误[%d]: %s, 原因: %v", e.Code, e.Message, e.Err)
}
return fmt.Sprintf("业务错误[%d]: %s", e.Code, e.Message)
}

func (e *BusinessError) Unwrap() error {
return e.Err
}

func NewBusinessError(code int, message string) *BusinessError {
return &BusinessError{
Code: code,
Message: message,
}
}

func WrapBusinessError(code int, message string, err error) *BusinessError {
return &BusinessError{
Code: code,
Message: message,
Err: err,
}
}

var (
ErrUserNotFound = NewBusinessError(CodeUserNotFound, "用户不存在")
ErrInvalidParam = NewBusinessError(CodeInvalidParam, "参数错误")
)

设计理由:自定义错误类型可以携带丰富的上下文信息(错误码、消息、原始错误),支持错误链追踪。

3. 数据模型 (model/user.go)

1
2
3
4
5
6
7
8
package model

type User struct {
ID int64 `json:"id"`
Name string `json:"name"`
Email string `json:"email"`
Status string `json:"status"`
}

4. DAO层 (dao/user_dao.go)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
package dao

import (
"database/sql"
"fmt"

"your-project/errors"
"your-project/model"
)

type UserDAO struct {
db *sql.DB
}

func NewUserDAO(db *sql.DB) *UserDAO {
return &UserDAO{db: db}
}

func (d *UserDAO) GetUserByID(userID int64) (*model.User, error) {
const query = "SELECT id, name, email, status FROM users WHERE id = ? AND deleted = 0"

var user model.User
err := d.db.QueryRow(query, userID).Scan(&user.ID, &user.Name, &user.Email, &user.Status)

if err != nil {
if err == sql.ErrNoRows {
return nil, errors.ErrUserNotFound
}
return nil, fmt.Errorf("查询用户失败 (ID: %d): %w", userID, err)
}

return &user, nil
}

func (d *UserDAO) CreateUser(user *model.User) error {
const query = "INSERT INTO users (name, email) VALUES (?, ?)"

result, err := d.db.Exec(query, user.Name, user.Email)
if err != nil {
return fmt.Errorf("创建用户失败: %w", err)
}

userID, err := result.LastInsertId()
if err != nil {
return fmt.Errorf("获取用户ID失败: %w", err)
}

user.ID = userID
return nil
}

DAO层错误传递策略

  • 遇到sql.ErrNoRows时返回业务错误ErrUserNotFound,因为”用户不存在”是业务逻辑的一部分
  • 其他数据库错误使用%w包装,保留原始错误信息但添加上下文
  • 不直接返回HTTP状态码,因为DAO层不应该知道HTTP协议

5. Service层 (service/user_service.go)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
package service

import (
"your-project/dao"
"your-project/errors"
"your-project/model"
)

type UserService struct {
userDAO *dao.UserDAO
}

func NewUserService(userDAO *dao.UserDAO) *UserService {
return &UserService{userDAO: userDAO}
}

type CreateUserRequest struct {
Name string `json:"name" binding:"required"`
Email string `json:"email" binding:"required,email"`
}

func (s *UserService) GetUser(userID int64) (*model.User, error) {
if userID <= 0 {
return nil, errors.WrapBusinessError(
errors.CodeInvalidParam,
"用户ID必须大于0",
nil,
)
}

user, err := s.userDAO.GetUserByID(userID)
if err != nil {
var businessErr *errors.BusinessError
if errors.As(err, &businessErr) {
return nil, err
}

return nil, errors.WrapBusinessError(
errors.CodeDBError,
"获取用户信息失败",
err,
)
}

if user.Status == "banned" {
return nil, errors.NewBusinessError(10004, "用户已被封禁")
}

return user, nil
}

func (s *UserService) CreateUser(req *CreateUserRequest) (*model.User, error) {
if req.Name == "" {
return nil, errors.WrapBusinessError(
errors.CodeInvalidParam,
"用户名不能为空",
nil,
)
}

if len(req.Name) < 2 || len(req.Name) > 20 {
return nil, errors.WrapBusinessError(
errors.CodeInvalidParam,
"用户名长度必须在2-20个字符之间",
nil,
)
}

user := &model.User{
Name: req.Name,
Email: req.Email,
}

if err := s.userDAO.CreateUser(user); err != nil {
return nil, errors.WrapBusinessError(
errors.CodeDBError,
"创建用户失败",
err,
)
}

return user, nil
}

Service层错误传递策略

  • 进行业务参数验证,将无效参数转换为业务错误
  • 使用errors.As()检查错误类型,如果是业务错误直接传递
  • 将DAO层的技术错误包装为业务错误,添加业务语义
  • 实现业务规则验证(如用户状态检查)
  • 不涉及HTTP概念,保持业务逻辑的纯粹性

6. Handler层 (handler/user_handler.go)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
package handler

import (
"net/http"
"strconv"

"github.com/gin-gonic/gin"

"your-project/errors"
"your-project/response"
"your-project/service"
)

type UserHandler struct {
userService *service.UserService
}

func NewUserHandler(userService *service.UserService) *UserHandler {
return &UserHandler{userService: userService}
}

func (h *UserHandler) GetUser(c *gin.Context) {
userIDStr := c.Param("id")
userID, err := strconv.ParseInt(userIDStr, 10, 64)
if err != nil {
response.BadRequest(c, "无效的用户ID")
return
}

user, err := h.userService.GetUser(userID)
if err != nil {
h.handleError(c, err)
return
}

response.Success(c, user)
}

func (h *UserHandler) CreateUser(c *gin.Context) {
var req service.CreateUserRequest
if err := c.ShouldBindJSON(&req); err != nil {
response.BadRequest(c, "请求参数格式错误")
return
}

user, err := h.userService.CreateUser(&req)
if err != nil {
h.handleError(c, err)
return
}

response.Success(c, user)
}

func (h *UserHandler) handleError(c *gin.Context, err error) {
var businessErr *errors.BusinessError
if errors.As(err, &businessErr) {
response.Error(c, businessErr.Code, businessErr.Message)
return
}

response.InternalError(c, err)
}

Handler层错误处理策略

  • 处理HTTP特定错误(参数解析、数据绑定)
  • 统一的错误处理入口handleError
  • 区分业务错误系统错误,分别处理
  • 将错误转换为统一的HTTP响应格式
  • 记录错误日志(在生产环境中可能隐藏内部错误细节)

7. 主程序 (main.go)

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
package main

import (
"database/sql"
"log"

"github.com/gin-gonic/gin"
_ "github.com/go-sql-driver/mysql"

"your-project/dao"
"your-project/handler"
"your-project/service"
)

func main() {
db, err := sql.Open("mysql", "user:password@/dbname")
if err != nil {
log.Fatal("数据库连接失败:", err)
}
defer db.Close()

userDAO := dao.NewUserDAO(db)
userService := service.NewUserService(userDAO)
userHandler := handler.NewUserHandler(userService)

r := gin.Default()

r.GET("/users/:id", userHandler.GetUser)
r.POST("/users", userHandler.CreateUser)

log.Println("服务器启动在 :8080")
r.Run(":8080")
}

错误处理流程示例

成功流程

1
2
3
4
5
6
7
请求: GET /users/123
Handler: 解析参数,调用service.GetUser(123)
Service: 参数验证,调用dao.GetUserByID(123)
DAO: 执行SQL,返回用户数据
Service: 返回用户数据
Handler: response.Success(c, user)
响应: { "code": 0, "message": "success", "data": { ... } }

错误流程(用户不存在)

1
2
3
4
5
6
7
请求: GET /users/999
Handler: 解析参数,调用service.GetUser(999)
Service: 参数验证,调用dao.GetUserByID(999)
DAO: SQL返回ErrNoRows,返回errors.ErrUserNotFound
Service: 传递errors.ErrUserNotFound
Handler: h.handleError → response.Error(10001, "用户不存在")
响应: { "code": 10001, "message": "用户不存在" }

错误流程(数据库连接失败)

1
2
3
4
5
6
7
请求: GET /users/123
Handler: 解析参数,调用service.GetUser(123)
Service: 参数验证,调用dao.GetUserByID(123)
DAO: 数据库连接失败,返回原始错误
Service: 包装为业务错误"获取用户信息失败"
Handler: h.handleError → response.InternalError
响应: { "code": 500, "message": "内部服务器错误" }

各层错误传递的设计哲学

1. 关注点分离 (Separation of Concerns)

  • DAO层只关心数据访问技术细节
  • Service层只关心业务逻辑和规则
  • Handler层只关心HTTP协议和用户交互

2. 错误信息 enrichment(丰富化)

错误在向上传递过程中不断添加上下文信息:

  • DAO: “查询失败”
  • Service: “获取用户信息失败:查询失败”
  • Handler: HTTP 500 + 日志记录

3. 错误类型转换

将底层技术错误转换为高层业务概念:

  • sql.ErrNoRowsErrUserNotFound → HTTP 404
  • sql.ErrConnDoneErrDBError → HTTP 500

4. 防御性编程

每层都进行适当的验证,尽早失败,避免错误传播到不合适的层级。

最佳实践总结

  1. 分层处理:各司其职,避免层间职责混淆
  2. 错误包装:使用错误链保留完整上下文
  3. 统一格式:客户端友好的错误响应格式
  4. 适当日志:在适当层级记录适当详情的日志
  5. 错误分类:区分可预期业务错误和意外系统错误

通过这种架构,可以实现清晰、可维护的错误处理机制,提高代码质量和系统稳定性。

前置并发知识

并发

goroutine

生产者&消费者 模型

Go-pool2.png
生产者-消费者模型是一种经典的并发编程模式,通过缓冲区解耦生产者和消费者,使它们可以独立、异步地工作。

核心组件

  1. 生产者(Producer)

    • 数据的产生者

    • 负责创建任务或数据

    • 将数据放入缓冲区

  2. 消费者(Consumer)

    • 数据的处理者

    • 从缓冲区获取数据

    • 执行具体的业务逻辑

  3. 缓冲区(Buffer/Queue)

    • 生产者和消费者之间的桥梁

    • 平衡生产速度和消费速度的差异

    • 提供流量控制和数据暂存

Go-pool3.png

任务分发的必要性

为什么需要任务分发?

直接创建Goroutine的问题

1
2
3
4
// ❌ 不推荐:无限制创建goroutine
for i := 0; i < 10000; i++ {
go processTask(i) // 可能创建过多goroutine!
}

问题分析

  1. 资源耗尽 - 内存、CPU过载
  2. 调度开销 - 上下文切换成本高
  3. 难以管理 - 无法控制并发数量
    `

协程池

梗概

什么是协程池?

协程池是一种复用Goroutine的技术,通过预先创建固定数量的工作协程,重复使用它们来处理任务,避免频繁创建和销毁的开销。

核心思想

Go-pool1.png

生产者-消费者模型的扩展

  • 生产者:提交任务到任务队列
  • 消费者:工作协程从队列获取任务执行
  • 缓冲区:任务队列平衡生产消费速度

实现思路

核心组件设计

1
2
3
4
5
6
7
8
9
10
11

type WorkerPool struct {
taskChan chan Task // 任务通道(缓冲队列)
resultChan chan Result // 结果通道
stopChan chan struct{} // 停止信号
wg sync.WaitGroup // 等待组(协调goroutine)

// 统计信息(原子操作保证线程安全)
SubmitSum int64 // 已提交任务数
CompleteSum int64 // 已完成任务数
}

工作流程

  1. 初始化阶段:创建指定数量的worker
  2. 任务提交:生产者向taskChan发送任务
  3. 任务处理:worker从taskChan接收并执行
  4. 结果收集:处理结果发送到resultChan
  5. 优雅关闭:通过stopChan协调关闭

代码实现

核心结构定义

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
package workerPool

import (
"errors"
"runtime"
"sync"
"sync/atomic"
)

// TaskFunc 任务函数类型
type TaskFunc func() (interface{}, error)

// Task 任务结构
type Task struct {
ID int // 任务ID(用于追踪)
Func TaskFunc // 要执行的任务函数
}

// TaskResult 任务执行结果
type TaskResult struct {
ID int // 对应任务ID
Result interface{} // 执行结果
Err error // 错误信息
}

// WorkerPool 协程池主体
type WorkerPool struct {
taskChan chan Task // 任务通道(缓冲队列)
resultChan chan TaskResult // 结果通道
stopChan chan struct{} // 停止信号通道
wg sync.WaitGroup // 等待组(协调goroutine生命周期)
// 原子操作统计(线程安全)
SubmitSum int64 // 已提交任务总数
CompleteSum int64 // 已完成任务总数
}

初始化协程池

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29

// NewWorkerPool 创建新的协程池

func NewWorkerPool(workerCount, queueSize int) *WorkerPool {
// 参数校验和默认值设置
if workerCount < 1 {
workerCount = runtime.NumCPU() * 2 // 默认:CPU核心数×2
}

if queueSize < workerCount {
queueSize = workerCount * 100 // 默认队列大小:worker数×100
}

// 初始化协程池实例
pool := &WorkerPool{
taskChan: make(chan Task, queueSize), // 带缓冲的任务通道
resultChan: make(chan TaskResult, queueSize), // 带缓冲的结果通道
stopChan: make(chan struct{}), // 无缓冲停止信号
}

// 创建worker协程
for i := 0; i < workerCount; i++ {
pool.wg.Add(1)
go pool.worker() // 启动worker goroutine
}

return pool
}

生产者:提交任务

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19

// Produce 提交任务到协程池(生产者)

func (p *WorkerPool) Produce(taskFunc TaskFunc) error {
// 封装任务
task := Task{
ID: int(atomic.AddInt64(&p.SubmitSum, 1)), // 原子操作生成任务ID
Func: taskFunc,
}

// 非阻塞发送任务
select {
case p.taskChan <- task: // 正常提交
return nil
case <-p.stopChan: // 协程池已关闭
return errors.New("pool stopped")
}
}

消费者:工作协程

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28

// worker 工作协程(消费者)

func (p *WorkerPool) worker() {
defer p.wg.Done() // 确保goroutine结束时通知WaitGroup

for {
select {
case task, ok := <-p.taskChan:
if !ok { // 通道已关闭且无剩余任务
return
}
// 执行具体任务
result, err := task.Func()
// 发送处理结果
p.resultChan <- TaskResult{
ID: task.ID,
Result: result,
Err: err,
}
// 原子操作更新完成计数
atomic.AddInt64(&p.CompleteSum, 1)
case <-p.stopChan: // 收到停止信号
return
}
}
}

结果收集和管理

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22

// GetResults 获取结果通道(只读)

func (p *WorkerPool) GetResults() <-chan TaskResult {
return p.resultChan
}

// GetInfo 获取统计信息(线程安全)

func (p *WorkerPool) GetInfo() (int64, int64) {
return atomic.LoadInt64(&p.SubmitSum), atomic.LoadInt64(&p.CompleteSum)
}

// Close 优雅关闭协程池

func (p *WorkerPool) Close() {
close(p.taskChan) // 关闭任务通道(停止接收新任务)
p.wg.Wait() // 等待所有worker完成任务
close(p.resultChan) // 关闭结果通道
close(p.stopChan) // 关闭停止信号
}

🎯 实战案例:文件关键词搜索

业务逻辑层

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
package service

import (
"bufio"
"fmt"
"os"
"strings"
"sync/atomic"
)

// Result 搜索结果结构

type Result struct {
Path string // 文件路径
Info []LineInfo // 匹配行信息
Err error // 错误信息
}

// LineInfo 行信息结构

type LineInfo struct {
Line int // 行号
Content string // 行内容
}

// Task 搜索任务

type Task struct {
Path string // 文件路径
Keyword string // 搜索关键词
}

// Search 文件搜索函数

func Search(task Task) (interface{}, error) {

// 打开文件

file, err := os.Open(task.Path)
if err != nil {
return Result{
Path: task.Path,
Err: fmt.Errorf("无法打开文件: %v", err),
}, err
}
defer file.Close()

var info []LineInfo
scanner := bufio.NewScanner(file)
lineNum := 0

// 逐行扫描
for scanner.Scan() {
lineNum++
line := scanner.Text()
if strings.Contains(line, task.Keyword) {
info = append(info, LineInfo{
Line: lineNum,
Content: strings.TrimSpace(line),
})
}
}

// 检查扫描错误
if err := scanner.Err(); err != nil {
return Result{
Path: task.Path,
Err: fmt.Errorf("读取文件错误: %v", err),
}, err
}

return Result{
Path: task.Path,
Info: info,
}, nil
}

// 全局统计(原子操作保证线程安全)

var (
totalFiles int64 // 总文件数
foundFiles int64 // 包含关键词的文件数
totalLines int64 // 总匹配行数
)

// SetTotal 设置总文件数

func SetTotal(num int64) {
atomic.StoreInt64(&totalFiles, num)
}

// AddFound 增加找到的文件计数

func AddFound(num int64) {
atomic.AddInt64(&foundFiles, num)
}

// AddLines 增加匹配行计数

func AddLines(num int64) {
atomic.AddInt64(&totalLines, num)
}

// GetInfo 获取统计信息

func GetInfo() (int, int, int) {
return int(atomic.LoadInt64(&totalFiles)),
int(atomic.LoadInt64(&foundFiles)),
int(atomic.LoadInt64(&totalLines))
}

主程序入口

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182

package main

import (
"Lesson_1/Lanshan-lesson5/service"
"Lesson_1/Lanshan-lesson5/workerPool"
"fmt"
"io/fs"
"os"
"path/filepath"
"runtime"
"sort"
"time"
)

func main() {
// 参数验证
if len(os.Args) != 3 {
fmt.Printf("使用方法: %s [目录路径] [搜索关键词]\n", os.Args[0])
os.Exit(1)
}
dir := os.Args[1]
keyword := os.Args[2]

// 目录存在性检查
if _, err := os.Stat(dir); os.IsNotExist(err) {
fmt.Printf("目录 '%s' 不存在\n", dir)
os.Exit(1)
}

// 初始化配置
workerCount := runtime.NumCPU() * 2
fmt.Printf("搜索目录: '%s', 关键词: '%s'\n", dir, keyword)

startTime := time.Now()

// 创建协程池
pool := workerPool.NewWorkerPool(workerCount, workerCount*100)

// 遍历目录获取文件列表
paths, err := walkDirectory(dir)
if err != nil {
fmt.Printf("遍历目录错误: %v\n", err)
os.Exit(1)
}

service.SetTotal(int64(len(paths)))
fmt.Printf("发现文件数: %d\n", len(paths))

// 结果收集通道
results := make(chan workerPool.TaskResult, workerCount*100)
done := make(chan bool, 1)

// 启动结果收集器
go collectResults(results, done, len(paths))

// 提交搜索任务
submittedTasks := 0
for _, path := range paths {
task := service.Task{Path: path, Keyword: keyword}

err := pool.Produce(func() (interface{}, error) {
return service.Search(task)
})

if err != nil {
fmt.Printf("任务提交失败: %v\n", err)
} else {
submittedTasks++
}
}

fmt.Printf("成功提交任务数: %d\n", submittedTasks)

// 转发结果
go forwardResults(pool.GetResults(), results)

// 等待所有任务完成
pool.Close()
close(results)

<-done // 等待结果收集完成

// 输出统计信息
total, found, lines := service.GetInfo()
elapsed := time.Since(startTime)

fmt.Printf("\n============= 搜索完成 =============\n")
fmt.Printf("总文件数: %d\n", total)
fmt.Printf("包含关键词的文件数: %d\n", found)
fmt.Printf("总匹配行数: %d\n", lines)
fmt.Printf("耗时: %v\n", elapsed)

submitted, completed := pool.GetInfo()
fmt.Printf("任务提交数: %d\n", submitted)
fmt.Printf("任务完成数: %d\n", completed)
}

// walkDirectory 遍历目录获取文件列表

func walkDirectory(dir string) ([]string, error) {
var paths []string

err := filepath.WalkDir(dir, func(path string, d fs.DirEntry, err error) error {
if err != nil {
fmt.Printf("访问错误 '%s': %v\n", path, err)
return nil // 跳过错误文件
}

if !d.IsDir() {
paths = append(paths, path)
}

return nil
})

if err != nil {
return nil, err
}

return paths, nil
}

// forwardResults 结果转发
func forwardResults(source <-chan workerPool.TaskResult, dest chan<- workerPool.TaskResult) {
for result := range source {
dest <- result
}
}

// collectResults 收集和处理结果
func collectResults(results <-chan workerPool.TaskResult, done chan<- bool, total int) {
var finalResults []service.Result
processed := 0

// 处理每个结果
for result := range results {
processed++

// 进度显示
if processed%100 == 0 {
progress := float64(processed) / float64(total) * 100
fmt.Printf("处理进度: %d/%d (%.2f%%)\n", processed, total, progress)
}

// 类型断言获取搜索结果
if searchResult, ok := result.Result.(service.Result); ok {
finalResults = append(finalResults, searchResult)

// 更新统计
if len(searchResult.Info) > 0 {
service.AddFound(1)
service.AddLines(int64(len(searchResult.Info)))
}
}
}
// 输出最终结果
fmt.Printf("\n================搜索结果================\n")
printResults(finalResults)
done <- true
}

// printResults 格式化输出结果
func printResults(results []service.Result) {
// 按文件路径排序
sort.Slice(results, func(i, j int) bool {
return results[i].Path < results[j].Path
})

for _, result := range results {
if result.Err != nil {
fmt.Printf("\n错误: %s - %v\n", result.Path, result.Err)
continue
}
if len(result.Info) > 0 {
fmt.Printf("\n%s:\n", result.Path)
for _, info := range result.Info {
fmt.Printf(" %d: %s\n", info.Line, info.Content)
}
}
}
}
0%