The Go Programming Language
http://golang.org/
Go Playground
Go Projects
Revel Web Framework
hetiansu5

Go 之 WaitGroup 底层实现

  •  
  •   hetiansu5 · Apr 4, 2021 · 2512 views
    This topic created in 1865 days ago, the information mentioned may be changed or developed.

    WaitGroup

    WaitGroup 用于等待一组线程的结束,父线程调用 Add 来增加等待的线程数,被等待的线程在结束后调用 Done 来将等待线程数减 1,父线程通过调用 Wait 阻塞等待所有结束(计数器清零)后进行唤醒。

    源码位置

    WaitGroup 的源码在 SDK 包的路径为src/sync/waitgroup.go

    数据结构

    type WaitGroup struct {
    	noCopy noCopy
    	state1 [3]uint32
    }
    

    1.noCopy noCopy

    noCopy 这个主要用来限制不能进行 copy,这里是为了避免 copy 后的 waitGroup 并发使用后,可能会与原 waitGroup 出现异常而 panic 。

    2.state1 [3]unit32

    数组的三个元素(非顺序):

    • counter 通过 Add()设置的子 goroutine 的数量,即被等待线程计数
    • waiter 通过 Wait()陷入阻塞的等待者计数
    • semap 信号量,用于唤醒阻塞 waiter

    这里需要注意一下 couter 、waiter 、semap 并不是顺序存储的,64bit 操作系统的原子操作需要保证 64bit 的内存对齐,在设计上我们需要保证 couter 和 waiter 的操作原子性。如果数组的首元素地址能被 8 整除,则 counter 和 waiter 刚好可以在同一块原子操作的 64bit 内存上,所以取数组前两个元素分别表示 couter 和 waiter ;如果不能被 8 整除(根据内存对齐的原理,地址必然是 4 的倍数),则取数组后两个。

    // 根据内存对齐方式的不同,返回 statep(couter 占用高 32bit 和 waiter 占用低 32bit)和 semap 的地址
    func (wg *WaitGroup) state() (statep *uint64, semap *uint32) {
    	if uintptr(unsafe.Pointer(&wg.state1))%8 == 0 {
    		return (*uint64)(unsafe.Pointer(&wg.state1)), &wg.state1[2]
    	} else {
    		return (*uint64)(unsafe.Pointer(&wg.state1[1])), &wg.state1[0]
    	}
    }
    

    alignment.png

    公共方法

    func (wg *WaitGroup) Add(delta int) //增加 waitGroup 子 goruntine 计数值
    func (wg *WaitGroup) Done() //当子 goruntine 完成后,将计数器-1
    func (wg *WaitGroup) Wait() //调用此方法的 goruntine,阻塞等待计数值为 0
    

    以下方法去除了 race 竞争检查的源代码。

    Add

    操作 counter 计数值加减。

    • 当 counter 增加时,直接 return
    • 当 counter 减少时, 判断条件:counter > 0 || waiter == 0
      • true 时,直接 return
      • false (等待线程都完成且有等待者)时,statep 复位为 0,通过 semap 信号量唤醒所有等待者
    func (wg *WaitGroup) Add(delta int) {
    	//从数组中拿到 stetep ( counter+waiter 的组合)和 semap 信号量的内存地址
    	statep, semap := wg.state()
    	//stetep 原子加操作,高位 32bit 是 counter,实际 counter+1
    	state := atomic.AddUint64(statep, uint64(delta)<<32)
    	//state 的高位 32bit,表示 couter 的计数值
    	v := int32(state >> 32)
    	//state 的低位 32bit,表示 waiter 的等待者数量
    	w := uint32(state)
    	// couter 不能小于 0
    	if v < 0 {
    		panic("sync: negative WaitGroup counter")
    	}
    	// 需要避免错误操作:Add 和 Wait 并发操作,否则会 panic
    	if w != 0 && delta > 0 && v == int32(delta) {
    		panic("sync: WaitGroup misuse: Add called concurrently with Wait")
    	}
    	// 如果还有等待线程未完成或者并没有等待者,直接 return
    	if v > 0 || w == 0 {
    		return
    	}
    	// 需要避免错误操作:Add 和 Wait 并发操作,否则会 panic
    	if *statep != state {
    		panic("sync: WaitGroup misuse: Add called concurrently with Wait")
    	}
    	// 将 statep 复位为 0 ( counter 和 waiter 都置为 0 )
    	*statep = 0
    	// 有多少个等待者就往 semap 循环发信号量(其实就是 semap+1 ),Wait 等待有一个调用	// runtime_Semacquire(semap)就是在等待这个信号量
    	for ; w != 0; w-- {
    		runtime_Semrelease(semap, false, 0)
    	}
    }
    

    Done

    被等待线程完成后调用 Done,将 counter 计数-1,表示线程结束

    func (wg *WaitGroup) Done() {
    	wg.Add(-1)
    }
    

    Wait

    主线程循环对 waiter 原子操作+1 直到成功后,然后阻塞等待 semap 信号量而被唤醒,最后 return

    func (wg *WaitGroup) Wait() {
    	// 从数组中拿到 stetep ( counter+waiter 的组合)和 semap 信号量的内存地址
    	statep, semap := wg.state()
    	for {
    		//从内存总线中加载最新的 statep 值
    		state := atomic.LoadUint64(statep)
    		//state 的高位 32bit,表示 couter 的计数值
    		v := int32(state >> 32)
    		//state 的低位 32bit,表示 waiter 的等待者数量
    		w := uint32(state)
    		//如果 couter 为 0,表示当前已经没有在运行的等待线程了
    		if v == 0 {
    			return
    		}
    		// CAS 操作 statep+1,低位属于 waiter,即 waiter+1
    		if atomic.CompareAndSwapUint64(statep, state, state+1) {
    			// CAS 操作成功后,阻塞等待 semap 信号为非零,竞争到会将 semap-1,并唤醒线程
    			runtime_Semacquire(semap)
    			if *statep != 0 {
    				panic("sync: WaitGroup is reused before previous Wait has returned")
    			}
    			return
    		}
    		// CAS 操作失败了,重新进入循环
    	}
    }
    
    4 replies    2021-04-07 13:32:21 +08:00
    makdon
        1
    makdon  
       Apr 4, 2021   ❤️ 2
    拉到最后竟然没有公众号 /博客 /培训班 /招聘
    raaaaaar
        2
    raaaaaar  
       Apr 4, 2021 via Android
    最近学了操作系统,发现就是个二元信号量。。
    hetiansu5
        3
    hetiansu5  
    OP
       Apr 6, 2021
    @makdon 哈哈,单纯输出而已,变相的加深理解
    kuro1
        4
    kuro1  
       Apr 7, 2021
    拉到最后竟然没有公众号 /博客 /培训班 /招聘+1
    About   ·   Help   ·   Advertise   ·   Blog   ·   API   ·   FAQ   ·   Solana   ·   4245 Online   Highest 6679   ·     Select Language
    创意工作者们的社区
    World is powered by solitude
    VERSION: 3.9.8.5 · 43ms · UTC 05:15 · PVG 13:15 · LAX 22:15 · JFK 01:15
    ♥ Do have faith in what you're doing.