Contents

Concurrency in Go

// Bit layout of Mutex.state plus the starvation threshold.
// The three low bits are flags; the waiter count is stored above them.
const (
	mutexLocked = 1 << iota // 1 << 0 = 1: the mutex is held
	mutexWoken              // 1 << 1 = 2: a goroutine has been woken (or is spinning), so Unlock need not wake another
	mutexStarving           // 1 << 2 = 4: the mutex is in starvation (handoff) mode
	mutexWaiterShift = iota // 3: state>>mutexWaiterShift is the number of waiters
  starvationThresholdNs = 1e6 // 1e6 ns (1ms): a waiter that waited longer than this becomes starving
)
// MSB 29: number of waiters
// LSB 3:
// starving(100)
// woken(010)
// locked(001)

// Mutex is a mutual-exclusion lock made of an atomically updated state word
// and a semaphore on which blocked goroutines sleep.
type Mutex struct {
	state int32  // flag bits (locked/woken/starving) in the low 3 bits, waiter count in the remaining bits
	sema  uint32 // semaphore handed to runtime_SemacquireMutex / runtime_Semrelease
}
  • competitive wakeup
  • handoff wakeup
// Unlock releases m. The fast path only clears the locked bit; if any other
// bit of the state is still set afterwards, unlockSlow decides whether a
// waiter must be woken.
func (m *Mutex) Unlock() {
	// Atomically drop the locked bit and look at what is left.
	new := atomic.AddInt32(&m.state, -mutexLocked)
	if new != 0 {
		// Waiters and/or flag bits remain (or the mutex was not locked at all).
		m.unlockSlow(new)
	}
}

// unlockSlow finishes an Unlock whose fast path left a non-zero state.
// new is the state observed right after the locked bit was subtracted.
//
// In normal mode it tries to take the "woken" token and release one waiter
// from the semaphore (competitive wakeup). In starvation mode it releases the
// semaphore with handoff so the next waiter receives ownership directly.
// It reports a fatal error if the mutex was not locked.
func (m *Mutex) unlockSlow(new int32) {
	// Re-adding mutexLocked reconstructs the pre-unlock state; if the locked
	// bit is not set there, the caller unlocked an unlocked mutex.
	if (new+mutexLocked)&mutexLocked == 0 {
		fatal("sync: unlock of unlocked mutex")
	}
	if new&mutexStarving == 0 {
		old := new
		for {
			// If there are no waiters or a goroutine has already
			// been woken or grabbed the lock, no need to wake anyone.
			// In starvation mode ownership is directly handed off from unlocking
			// goroutine to the next waiter. We are not part of this chain,
			// since we did not observe mutexStarving when we unlocked the mutex above.
			// So get off the way.
			if old>>mutexWaiterShift == 0 || old&(mutexLocked|mutexWoken|mutexStarving) != 0 {
				return
			}
			// Grab the right to wake someone: one waiter fewer, woken flag set.
			new = (old - 1<<mutexWaiterShift) | mutexWoken
			if atomic.CompareAndSwapInt32(&m.state, old, new) {
				runtime_Semrelease(&m.sema, false, 1)
				return
			}
			// The CAS lost a race; reload the state and retry.
			old = m.state
		}
	} else {
		// Starving mode: handoff mutex ownership to the next waiter, and yield
		// our time slice so that the next waiter can start to run immediately.
		// Note: mutexLocked is not set, the waiter will set it after wakeup.
		// But mutex is still considered locked if mutexStarving is set,
		// so new coming goroutines won't acquire it.
		runtime_Semrelease(&m.sema, true, 1)
	}
}
  • 4 actions:

    • spin
    • compete state
    • sleep
    • wake
  • when starving:

    • goroutine won’t compete for the lock (i.e. it won’t try to change the state from unlocked to locked)
    • goroutine is going to sleep
    • goroutine won’t spin
  • when woken:

    • competitive wakeup won’t happen
  • when locked:

    • goroutine can’t compete for the lock
    • goroutine is going to sleep
    • goroutine will spin when not starving
// Lock acquires m. The fast path is a single CAS from the fully idle state
// (0) to mutexLocked; anything else is handled by lockSlow.
func (m *Mutex) Lock() {
	if atomic.CompareAndSwapInt32(&m.state, 0, mutexLocked) {
		return // lock acquired on the fast path
	}
	m.lockSlow()
}

func (m *Mutex) lockSlow() {
  	var waitStartTime int64
	starving := false
	awoke := false
	iter := 0
	old := m.state
	for {
    	// action: spin
		if old&(mutexLocked|mutexStarving) == mutexLocked && runtime_canSpin(iter) { // locked but not starving
			if !awoke &&
				old&mutexWoken == 0 &&
				old>>mutexWaiterShift != 0 && // there are other waiters
				atomic.CompareAndSwapInt32(&m.state, old, old|mutexWoken) { // set mutexWoken flag to inform `Unlock` to not wake blocked goroutines
				awoke = true // set current goroutine woken
			}
			runtime_doSpin()
			iter++
			old = m.state // read state
			continue
		}

		new := old
		if old&mutexStarving == 0 { // not starving
			new |= mutexLocked
		}
		if old&(mutexLocked|mutexStarving) != 0 { // locked or starving (when starving, goroutine does not compete the lock)
			new += 1 << mutexWaiterShift
		}
		if starving && old&mutexLocked != 0 { // locked
			new |= mutexStarving
		}
		if awoke { // 1. awoke from competitive wakeup  2. awoke from spinning
			if new&mutexWoken == 0 {
				throw("sync: inconsistent mutex state")
			}
			new &^= mutexWoken // and not operation, clear mutexWoken flag
		}

		// action: compete state
		if atomic.CompareAndSwapInt32(&m.state, old, new) { // if no other goroutine has changed the state, load the state of the current goroutine
			if old&(mutexLocked|mutexStarving) == 0 { // not locked and not starving
				break // has aquired the lock
			}

			// action: sleep
			queueLifo := waitStartTime != 0 // whether already waiting before
			if waitStartTime == 0 {
				waitStartTime = runtime_nanotime()
			}
			runtime_SemacquireMutex(&m.sema, queueLifo, 1) // if waiting before, queue at the front of the queue, vice versa

			// action: wake (from handoff or competitive wakeup)
			starving = starving || runtime_nanotime()-waitStartTime > starvationThresholdNs // exceeds 1ms
			old = m.state
			if old&mutexStarving != 0 { // handoff wakeup
				if old&(mutexLocked|mutexWoken) != 0 || old>>mutexWaiterShift == 0 { // locked or woken or no waiters
					throw("sync: inconsistent mutex state")
				}
				delta := int32(mutexLocked - 1<<mutexWaiterShift) // aquire the lock and decr number of waiters
				if !starving || old>>mutexWaiterShift == 1 {
					delta -= mutexStarving //exit starvation mode
				}
				atomic.AddInt32(&m.state, delta)
				break
			}
			awoke = true
			iter = 0
		} else {
			old = m.state
		}
  }
}
// WaitGroup tracks a counter of outstanding work and the goroutines blocked
// in Wait until that counter returns to zero.
type WaitGroup struct {
	state atomic.Uint64 // high 32 bits are counter, low 32 bits are waiter count.
	sema  uint32        // semaphore that Wait sleeps on and Add releases
}

// Wait blocks until the WaitGroup counter is zero.
//
// It registers itself as a waiter with a CAS on the packed state and then
// sleeps on the semaphore. It panics if the state is non-zero after wakeup,
// which means the WaitGroup was reused before this Wait returned.
func (wg *WaitGroup) Wait() {
	for {
		state := wg.state.Load()
		v := int32(state >> 32) // counter (high 32 bits)
		if v == 0 {
			// Counter is 0, no need to wait.
			return
		}
		if wg.state.CompareAndSwap(state, state+1) { // increment waiter count
			runtime_Semacquire(&wg.sema)

			// Woken by Add, which stores 0 into state before releasing waiters.
			if wg.state.Load() != 0 {
				panic("sync: WaitGroup is reused before previous Wait has returned")
			}
			return
		}
		// CAS failed because the state changed concurrently; retry.
	}
}

// Done decrements the WaitGroup counter by one; it is shorthand for Add(-1).
func (wg *WaitGroup) Done() {
	wg.Add(-1)
}

// Add adds delta (which may be negative) to the WaitGroup counter.
// When the counter reaches zero while waiters are registered, Add resets the
// state to 0 and releases every waiter.
//
// It panics if the counter becomes negative or if it detects that Add was
// called concurrently with Wait.
func (wg *WaitGroup) Add(delta int) {
	// The counter lives in the high 32 bits, so shift delta into place.
	state := wg.state.Add(uint64(delta) << 32)
	v := int32(state >> 32) // counter after the update
	w := uint32(state)      // waiter count
	if v < 0 {
		panic("sync: negative WaitGroup counter")
	}
	// Waiters exist, yet the counter just went from 0 to delta: misuse.
	if w != 0 && delta > 0 && v == int32(delta) {
		panic("sync: WaitGroup misuse: Add called concurrently with Wait")
	}
	// Nothing to wake: work is still outstanding, or nobody is waiting.
	if v > 0 || w == 0 {
		return
	}
	// This goroutine has set counter to 0 when waiters > 0.
	// Now there can't be concurrent mutations of state:
	// - Adds must not happen concurrently with Wait,
	// - Wait does not increment waiters if it sees counter == 0.
	// Still do a cheap sanity check to detect WaitGroup misuse.
	if wg.state.Load() != state {
		panic("sync: WaitGroup misuse: Add called concurrently with Wait")
	}
	// Reset waiters count to 0.
	wg.state.Store(0)
	// Release the semaphore once per registered waiter.
	for ; w != 0; w-- {
		runtime_Semrelease(&wg.sema, false, 0)
	}
}

// Cond is a condition variable: goroutines park on its notify list until
// Signal or Broadcast wakes them.
type Cond struct {
	// L is held while observing or changing the condition
	L Locker
	notify  notifyList // list of waiting goroutines, managed by the runtime_notifyList* helpers
}

// NewCond returns a new Cond with Locker l.
func NewCond(l Locker) *Cond {
	return &Cond{L: l}
}

// Wait enqueues the caller on the notify list, unlocks c.L, blocks until
// notified, and re-locks c.L before returning. Because the ticket is taken
// before c.L is released, the caller must hold c.L when calling Wait.
func (c *Cond) Wait() {
	t := runtime_notifyListAdd(&c.notify) // ticket identifying this waiter
	c.L.Unlock()
	runtime_notifyListWait(&c.notify, t)
	c.L.Lock()
}

// Signal notifies one goroutine on the notify list, if any.
func (c *Cond) Signal() {
	runtime_notifyListNotifyOne(&c.notify)
}

// Broadcast notifies all goroutines on the notify list.
func (c *Cond) Broadcast() {
	runtime_notifyListNotifyAll(&c.notify)
}
package main

import (
	"log"
	"math/rand"
	"sync"
	"time"
)

// main demonstrates sync.Cond: ten "athlete" goroutines each sleep for a
// random number of seconds, bump the shared ready counter under the lock and
// broadcast; the "referee" (main goroutine) waits on the condition until all
// ten have reported in.
func main() {
	const athletes = 10

	cond := sync.NewCond(new(sync.Mutex))
	ready := 0

	// athlete warms up for 0-9 seconds, then reports ready and wakes the referee.
	athlete := func(id int) {
		time.Sleep(time.Duration(rand.Int63n(10)) * time.Second)

		cond.L.Lock()
		ready++
		cond.L.Unlock()

		log.Printf("运动员#%d 已准备就绪\n", id)
		cond.Broadcast()
	}
	for id := 0; id < athletes; id++ {
		go athlete(id)
	}

	// The condition is re-checked after every wakeup, as Cond requires.
	cond.L.Lock()
	for ready != athletes {
		cond.Wait()
		log.Println("裁判员被唤醒一次")
	}
	cond.L.Unlock()

	log.Println("所有运动员都准备就绪。比赛开始,3,2,1, ......")
}

// Once runs a function at most once, guarded by a done flag and a mutex.
type Once struct {
	done uint32 // 0 until the function passed to Do has returned, then 1
	m    Mutex  // serializes callers that reach the slow path
}

// Do calls f only if the done flag is still 0. The flag is read atomically
// here without taking the mutex; callers that see 0 go through doSlow.
func (o *Once) Do(f func()) {
	if atomic.LoadUint32(&o.done) == 0 {
		o.doSlow(f)
	}
}

// doSlow takes the mutex and re-checks done, so that of several concurrent
// callers only the first one runs f.
func (o *Once) doSlow(f func()) {
	o.m.Lock()
	defer o.m.Unlock()
	if o.done == 0 {
		// Deferred so the flag is set only after f has returned (or panicked),
		// and before the mutex is released.
		defer atomic.StoreUint32(&o.done, 1)
		f()
	}
}
// RWMap is a map[int]int made safe for concurrent use by an embedded
// read-write lock.
type RWMap struct {
	sync.RWMutex // guards m
	m            map[int]int
}

// NewRWMap returns an empty RWMap with space pre-allocated for n entries.
func NewRWMap(n int) *RWMap {
	rw := new(RWMap)
	rw.m = make(map[int]int, n)
	return rw
}

// Get returns the value stored under k and whether k was present.
func (rw *RWMap) Get(k int) (int, bool) {
	rw.RLock()
	val, ok := rw.m[k] // read under the read lock
	rw.RUnlock()
	return val, ok
}

// Set stores v under k, replacing any previous value.
func (rw *RWMap) Set(k int, v int) {
	rw.Lock()
	defer rw.Unlock()
	rw.m[k] = v
}

// Delete removes k; it is a no-op when k is absent.
func (rw *RWMap) Delete(k int) {
	rw.Lock()
	defer rw.Unlock()
	delete(rw.m, k)
}

// Len returns the number of entries.
func (rw *RWMap) Len() int {
	rw.RLock()
	size := len(rw.m)
	rw.RUnlock()
	return size
}

// Each calls f for every entry until f returns false. The read lock is held
// for the whole iteration, so f must not call Set or Delete on the same map.
func (rw *RWMap) Each(f func(k, v int) bool) {
	rw.RLock()
	defer rw.RUnlock()

	for key, val := range rw.m {
		if keepGoing := f(key, val); !keepGoing {
			break
		}
	}
}
package main

import (
	"sync"
)

// SHARD_COUNT is the number of shards that New allocates. Changing it only
// affects ShardMaps created afterwards.
var SHARD_COUNT = 32

// ShardMap is a concurrent string-keyed map split into independently locked
// shards so that operations on different shards do not contend.
type ShardMap []*RWMap

// RWMap is a single shard: a plain map guarded by its own read-write lock.
type RWMap struct {
	items map[string]any
	sync.RWMutex
}

// New returns a ShardMap with SHARD_COUNT empty shards.
func New() ShardMap {
	m := make(ShardMap, SHARD_COUNT)
	for i := range m {
		m[i] = &RWMap{items: make(map[string]any)}
	}
	return m
}

// GetShard returns the shard responsible for key, or nil if m has no shards.
//
// The shard index is derived from len(m) rather than from the mutable global
// SHARD_COUNT, so lookups stay in range (and stay consistent) even if
// SHARD_COUNT is changed after m was created.
func (m ShardMap) GetShard(key string) *RWMap {
	if len(m) == 0 {
		return nil
	}
	return m[uint(fnv32(key))%uint(len(m))]
}

// fnv32 returns the 32-bit FNV-1 hash of key (multiply by the prime, then
// XOR in each byte).
func fnv32(key string) uint32 {
	const (
		offset32 = uint32(2166136261)
		prime32  = uint32(16777619)
	)
	hash := offset32
	for i := 0; i < len(key); i++ {
		hash *= prime32
		hash ^= uint32(key[i])
	}
	return hash
}

/images/318.png

// Map is a concurrent map that separates reads from writes: lookups are
// served from the read view, while writes go to dirty under mu.
type Map struct {
	mu     Mutex          // guards dirty
	read   atomic.Value   // read-only map
	dirty  map[any]*entry // map that takes the writes
	misses int            // incremented each time a lookup falls through read; when misses == len(dirty), dirty is promoted to read
}

// readOnly is the read-only view of the map — presumably the value kept in
// Map.read (the store itself is not shown in this excerpt).
type readOnly struct {
	m       map[any]*entry // key -> entry table of this view
	amended bool           // true when dirty's data differs from m
}

// entry is the slot a key maps to, in both the read map and the dirty map.
type entry struct {
	p unsafe.Pointer // pointer to the stored value — presumably updated with atomic/CAS operations; the accessors are not part of this excerpt
}
  • 读写分离
    • read 和 dirty 各自维护一套 key,且指向同一个 value
    • 通过读写分离,降低锁时间来提高效率,适用于读多写少的场景
    • 新 key 写到 dirty 中,若为 nil 则创建,并将 read 中未被标记删除的元素拷贝到 dirty
  • 延迟删除
    • read 通过 CAS 将对应的 value 清零
package main

import (
	"fmt"
	"sync"
	"time"
)

// Data is the sample payload type recycled through the sync.Pool in main.
type Data struct {
	Value int // payload; main sets it to 42 before returning the object to the pool
}

// main demonstrates sync.Pool: it takes a *Data from the pool, mutates it,
// puts it back, and then lets a second goroutine take an object from the same
// pool and print what it finds.
func main() {
	pool := sync.Pool{
		New: func() any { return new(Data) },
	}

	first := pool.Get().(*Data)
	first.Value = 42
	fmt.Println("Data:", first.Value)
	pool.Put(first)

	go func() {
		second := pool.Get().(*Data)
		fmt.Println("Another Data:", second.Value)
	}()

	// Crude synchronization: give the goroutine time to run before exiting.
	time.Sleep(time.Second)
}

/images/317.png

// Pool holds per-P storage for the current cycle (local) plus the storage
// left over from the previous cycle (victim).
//
// A Pool must not be copied after first use.
type Pool struct {
	noCopy noCopy // marker field backing the must-not-copy rule stated above

	local     unsafe.Pointer // local fixed-size per-P pool, actual type is [P]poolLocal
	localSize uintptr        // size of the local array

	victim     unsafe.Pointer // local from previous cycle
	victimSize uintptr        // size of victims array

	// New optionally specifies a function to generate
	// a value when Get would otherwise return nil.
	// It may not be changed concurrently with calls to Get.
	New func() any
}

// Local per-P Pool appendix.
// It holds one private slot for the owning P and a shared chain that other
// Ps may take from.
type poolLocalInternal struct {
	private any       // Can be used only by the respective P.
	shared  poolChain // Local P can pushHead/popHead; any P can popTail.
}

// poolLocal is poolLocalInternal padded up to a multiple of 128 bytes.
type poolLocal struct {
	poolLocalInternal

	// Prevents false sharing on widespread platforms with
	// 128 mod (cache line size) = 0 .
	pad [128 - unsafe.Sizeof(poolLocalInternal{})%128]byte
}
  • 从本地的 private 字段中获取可用元素
  • 如果没有,就从本地的 shared 获取一个
  • 如果没有,就从其它的 shared 中偷一个
  • 如果没有,就从 victim 中找
  • 如果没有,就使用 New 函数创建一个
  • 如果没有设置 New 方法,将返回 nil
  • 设置本地 private
  • 如果 private 已设置,就把此元素 push 到本地队列中

Pool 是如何实现并发安全的?

  • 对于生产者而言,当前 P 操作的都是自身的 poolLocal,避免了数据竞争
  • 对于消费者,对其他的 P 会进行 popTail 操作,这时会和 pushHead/popHead 操作形成数据竞争。通过原子操作避免了读写冲突
package main

import (
	"context"
	"fmt"
)

// main demonstrates context.WithValue: every call wraps the previous context
// in a new layer, and Value walks the chain outwards until it finds the key.
func main() {
	// Context keys should not be a built-in type such as string: an
	// unexported key type prevents collisions with keys set by other
	// packages (and silences go vet / staticcheck SA1029).
	type ctxKey string

	ctx := context.TODO()
	ctx = context.WithValue(ctx, ctxKey("key1"), "0001")
	ctx = context.WithValue(ctx, ctxKey("key2"), "0002")
	ctx = context.WithValue(ctx, ctxKey("key3"), "0003")
	ctx = context.WithValue(ctx, ctxKey("key4"), "0004")

	fmt.Println(ctx.Value(ctxKey("key1")))
	fmt.Println(ctx.Value(ctxKey("key3")))
}
package main

import (
	"context"
	"fmt"
	"sync"
)

// main demonstrates context.WithCancel: a worker goroutine blocks until the
// context is cancelled, then reports its exit; main cancels and waits for it.
func main() {
	var wg sync.WaitGroup
	ctx, cancel := context.WithCancel(context.Background())

	wg.Add(1)
	go func() {
		// Deferred calls run last-in first-out: print first, then Done.
		defer wg.Done()
		defer fmt.Println("goroutine exit")

		<-ctx.Done() // park until cancel is called
	}()

	fmt.Println("to cancel")
	cancel()
	wg.Wait()
}
package main

import (
	"context"
	"fmt"
	"sync"
	"time"
)

// main demonstrates context.WithTimeout: the worker goroutine blocks until
// the two-second deadline expires, then reports its exit.
func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
	defer cancel()

	var wg sync.WaitGroup
	wg.Add(1)
	go func() {
		// Deferred calls run last-in first-out: print first, then Done.
		defer wg.Done()
		defer fmt.Println("goroutine exit")

		<-ctx.Done() // released when the timeout fires
	}()
	wg.Wait()
}

交替打印字母、数字

package main

import (
	"fmt"
	"sync"
)

type Token struct{}

func pass(msg any, ch, nextCh chan Token, wg *sync.WaitGroup) {
	for i := 0; i < 3; i++ {
		token := <-ch // 取得令牌
		fmt.Println(msg)
		nextCh <- token // 传递令牌
	}
	wg.Done()
}
func main() {
	chs := []chan Token{
		make(chan Token, 1),
		make(chan Token, 1),
	}
	wg := &sync.WaitGroup{}
	wg.Add(2)
	go pass("A", chs[0], chs[1], wg)
	go pass(0, chs[1], chs[0], wg)
	chs[0] <- Token{}
	wg.Wait()
}

多个 goroutine 工作,只要其中一个完成任务发送信号

多个源 channel 输入、一个目的 channel 输出的情况

反射

package main

import "reflect"

// fanInReflect merges any number of input channels into one output channel
// using reflect.Select. The output is closed once every input is closed.
func fanInReflect(chans ...<-chan any) <-chan any {
	merged := make(chan any)
	go func() {
		defer close(merged)

		// One receive case per input channel.
		pending := make([]reflect.SelectCase, 0, len(chans))
		for i := range chans {
			pending = append(pending, reflect.SelectCase{
				Dir:  reflect.SelectRecv,
				Chan: reflect.ValueOf(chans[i]),
			})
		}

		// Keep selecting until no open input remains.
		for len(pending) != 0 {
			idx, val, open := reflect.Select(pending)
			if open {
				merged <- val.Interface()
				continue
			}
			// That input is closed: drop its case.
			pending = append(pending[:idx], pending[idx+1:]...)
		}
	}()
	return merged
}

递归

package main

// fanInRec merges any number of input channels into one by recursively
// splitting the inputs in half and joining each pair with mergeTwo.
// With no inputs it returns an already-closed channel; with a single input
// it returns that channel unchanged.
func fanInRec(chans ...<-chan any) <-chan any {
	n := len(chans)
	if n == 0 {
		empty := make(chan any)
		close(empty)
		return empty
	}
	if n == 1 {
		return chans[0]
	}

	mid := n / 2
	left := fanInRec(chans[:mid]...)
	right := fanInRec(chans[mid:]...)
	return mergeTwo(left, right)
}

// mergeTwo forwards every value received from a and b onto the returned
// channel, which is closed once both inputs are closed (or nil).
func mergeTwo(a, b <-chan any) <-chan any {
	c := make(chan any)
	go func() {
		defer close(c)
		for a != nil || b != nil { // while at least one input can still deliver
			select {
			case v, ok := <-a:
				if !ok {
					// a is closed: a nil channel never becomes ready, which
					// disables this case for the rest of the loop.
					a = nil
					continue
				}
				c <- v
			case v, ok := <-b:
				if !ok { // b is closed: disable its case the same way
					b = nil
					continue
				}
				c <- v
			}
		}
	}()
	return c
}

package main

// fanOut copies every value received from ch to each channel in out and
// closes all output channels once ch is closed.
//
// With async == false each value is delivered to the outputs one after the
// other, so a slow consumer stalls the others. With async == true each
// delivery runs in its own goroutine.
//
// NOTE(review): in async mode the outputs are closed as soon as ch is
// drained, without waiting for in-flight deliveries; a delivery that is
// still pending at that point panics with "send on closed channel". Fixing
// this needs a sync.WaitGroup (or similar) around the async sends.
func fanOut(ch <-chan any, out []chan any, async bool) {
	go func() {
		defer func() { // close every output on exit
			for i := 0; i < len(out); i++ {
				close(out[i])
			}
		}()
		for v := range ch { // read from the input channel
			for i := 0; i < len(out); i++ {
				if !async {
					out[i] <- v // synchronous delivery
					continue
				}
				// Pass the destination and value as arguments so the
				// goroutine does not share the loop variables: before
				// Go 1.22 a closure capturing i and v could observe later
				// iterations (wrong value, or an index past the end of out).
				go func(dst chan<- any, val any) {
					dst <- val // asynchronous delivery
				}(out[i], v)
			}
		}
	}()
}
package main

// asStream turns values into a channel that emits them in order. The channel
// is closed after the last value, or earlier if done becomes ready first.
func asStream(done <-chan struct{}, values ...any) <-chan any {
	stream := make(chan any)
	emit := func() {
		defer close(stream) // always close on exit
		for i := range values {
			select {
			case stream <- values[i]: // hand the next element to the consumer
			case <-done:
				return
			}
		}
	}
	go emit()
	return stream
}

// takeN returns a channel that emits at most the first num values of
// valueStream. The channel is closed after num values, when valueStream is
// closed, or when done becomes ready, whichever happens first.
func takeN(done <-chan struct{}, valueStream <-chan any, num int) <-chan any {
	takeStream := make(chan any) // output stream
	go func() {
		defer close(takeStream)
		for i := 0; i < num; i++ {
			// Receive and send are separate selects. In the combined form
			// `case takeStream <- <-valueStream:` the receive is evaluated
			// before the select starts, so a blocked input made the goroutine
			// ignore done, and a closed input produced spurious nil values.
			var (
				v  any
				ok bool
			)
			select {
			case <-done:
				return
			case v, ok = <-valueStream:
				if !ok { // input exhausted before num values arrived
					return
				}
			}
			select {
			case <-done:
				return
			case takeStream <- v:
			}
		}
	}()
	return takeStream
}
package main

// mapChan applies fn to every value received from in and emits the results,
// in order, on the returned channel, which is closed when in is closed.
// A nil input yields an already-closed channel.
func mapChan(in <-chan any, fn func(any) any) <-chan any {
	results := make(chan any)
	if in == nil { // guard: nothing will ever arrive
		close(results)
		return results
	}

	// The mapping itself runs in a background goroutine.
	transform := func() {
		defer close(results)
		for item := range in {
			results <- fn(item)
		}
	}
	go transform()
	return results
}

// reduce folds every value received from in into one result: the first value
// seeds the accumulator and fn combines it with each later value. It returns
// nil when in is nil or is closed without delivering anything.
func reduce(in <-chan any, fn func(r, v any) any) any {
	if in == nil { // guard: a nil channel would block forever
		return nil
	}

	acc, ok := <-in // seed with the first element
	if !ok {
		return nil
	}
	for {
		next, more := <-in
		if !more {
			return acc
		}
		acc = fn(acc, next)
	}
}
package main

import (
	"context"
	"fmt"
	"golang.org/x/sync/semaphore"
	"time"
)

// Package-level state of the worker-pool demo.
var (
	maxWorkers int64 = 10                                // number of workers
	sema             = semaphore.NewWeighted(maxWorkers) // semaphore with one unit per worker
	tasks            = make([]int, maxWorkers*4)         // task slots: four times the number of workers
)

// main runs every task on a pool of at most maxWorkers goroutines, using the
// weighted semaphore as the pool: one unit is held per running worker.
func main() {
	ctx := context.Background()
	for i := range tasks {
		// Blocks here while all workers are busy, until one releases its slot.
		// Acquire only fails when ctx is done; stop handing out work then
		// instead of silently starting a worker without a slot.
		if err := sema.Acquire(ctx, 1); err != nil {
			fmt.Println("acquire worker slot:", err)
			break
		}
		// Start the worker goroutine.
		go func(i int) {
			defer sema.Release(1)
			// Simulate a slow operation.
			time.Sleep(100 * time.Millisecond)
			tasks[i] = i + 1
		}(i)
	}
	// Acquiring every unit succeeds only once all workers have released
	// theirs, i.e. once all started workers are finished.
	if err := sema.Acquire(ctx, maxWorkers); err != nil {
		fmt.Println("wait for workers:", err)
		return
	}
	fmt.Println(tasks)
}
package main

import (
	"fmt"
	"golang.org/x/sync/singleflight"
	"math/rand"
	"sync"
)

// getDigit returns a pseudo-random digit in [0, 10) and a nil error. It has
// the signature singleflight.Group.Do expects for its callback.
var getDigit = func() (any, error) {
	digit := rand.Intn(10)
	return digit, nil
}

// main starts ten goroutines that all request the same key through one
// singleflight.Group and prints, for each, the value, the error and whether
// the result was shared with other callers.
func main() {
	const callers = 10

	var (
		wg sync.WaitGroup
		g  singleflight.Group
	)
	wg.Add(callers)
	for i := 0; i < callers; i++ {
		go func() {
			defer wg.Done()

			v, err, shared := g.Do("key", getDigit)
			fmt.Println(v, err, shared)
		}()
	}
	wg.Wait()
}
// call is one in-flight (or just finished) execution for a key.
type call struct {
	wg  sync.WaitGroup // duplicate callers block on this until the call finishes
	val any            // result of the call
	err error          // error of the call
}

// flightGroup suppresses duplicate work: while a call for a key is in
// flight, further callers for the same key wait for it and share its result.
// The zero value is ready to use.
type flightGroup struct {
	calls map[string]*call // in-flight calls by key; allocated lazily
	lock  sync.Mutex       // guards calls
}

// Do runs fn for key and returns its result. If a call for key is already in
// flight, Do waits for that call and returns its result without running fn.
func (g *flightGroup) Do(key string, fn func() (any, error)) (any, error) {
	c, done := g.createCall(key)
	if done {
		return c.val, c.err
	}

	g.makeCall(c, key, fn)
	return c.val, c.err
}

// createCall returns the call registered for key. If one is already in
// flight it waits for it and reports done == true; otherwise it registers a
// new call, which the caller must then run with makeCall.
func (g *flightGroup) createCall(key string) (c *call, done bool) {
	g.lock.Lock()
	if existing, ok := g.calls[key]; ok {
		g.lock.Unlock()
		existing.wg.Wait()
		return existing, true
	}

	// Allocate the map on first use: writing to a nil map panics, which made
	// a zero-value flightGroup unusable.
	if g.calls == nil {
		g.calls = make(map[string]*call)
	}
	c = new(call)
	c.wg.Add(1)
	g.calls[key] = c
	g.lock.Unlock()

	return c, false
}

// makeCall runs fn, stores its result in c, and then unregisters key and
// releases the waiters. The cleanup is deferred so that it also happens if
// fn panics.
func (g *flightGroup) makeCall(c *call, key string, fn func() (any, error)) {
	defer func() {
		g.lock.Lock()
		delete(g.calls, key)
		g.lock.Unlock()
		c.wg.Done()
	}()

	c.val, c.err = fn()
}
package main

import (
	"context"
	"github.com/marusama/cyclicbarrier"
	"golang.org/x/sync/semaphore"
	"sync"
)

// main produces 100 water molecules: 300 goroutines each emit one atom
// ("H" or "O") through the H2O coordinator, and the atoms are then read back
// three at a time and printed as one molecule per line.
func main() {
	// 300 atoms, 300 goroutines; each goroutine produces a single atom.
	const molecules = 100
	const atoms = 3 * molecules

	ch := make(chan string, atoms)

	// Used to wait for all goroutines to finish.
	var wg sync.WaitGroup
	wg.Add(atoms)

	h2o := New()

	// 200 hydrogen goroutines.
	for i := 0; i < 2*molecules; i++ {
		go func() {
			defer wg.Done()
			h2o.hydrogen(ch)
		}()
	}
	// 100 oxygen goroutines.
	for i := 0; i < molecules; i++ {
		go func() {
			defer wg.Done()
			h2o.oxygen(ch)
		}()
	}

	// Wait until every goroutine is done.
	wg.Wait()

	// Atoms are consumed in groups of three; each group is expected to hold
	// two hydrogen atoms and one oxygen atom to form a valid molecule.
	for i := 0; i < molecules; i++ {
		first, second, third := <-ch, <-ch, <-ch
		println(first + second + third)
	}
}

// H2O coordinates hydrogen and oxygen goroutines so that atoms are released
// in groups of three.
type H2O struct {
	semaH *semaphore.Weighted         // semaphore for hydrogen atoms
	semaO *semaphore.Weighted         // semaphore for oxygen atoms
	b     cyclicbarrier.CyclicBarrier // cyclic barrier that gates the assembly of each molecule
}

// New returns an H2O coordinator with two hydrogen slots, one oxygen slot and
// a barrier for three parties.
func New() *H2O {
	h := new(H2O)
	h.semaH = semaphore.NewWeighted(2) // two hydrogen atoms per molecule
	h.semaO = semaphore.NewWeighted(1) // one oxygen atom per molecule
	h.b = cyclicbarrier.New(3)         // three atoms are needed to assemble
	return h
}
// hydrogen takes one of the hydrogen slots, emits "H" on ch, waits at the
// barrier, and then frees its slot again.
func (h2o *H2O) hydrogen(ch chan string) {
	ctx := context.Background()

	h2o.semaH.Acquire(ctx, 1)
	defer h2o.semaH.Release(1) // free the hydrogen slot on the way out

	ch <- "H"        // emit H
	h2o.b.Await(ctx) // wait for the barrier to open
}
// oxygen takes the oxygen slot, emits "O" on ch, waits at the barrier, and
// then frees its slot again.
func (h2o *H2O) oxygen(ch chan string) {
	ctx := context.Background()

	h2o.semaO.Acquire(ctx, 1)
	defer h2o.semaO.Release(1) // free the oxygen slot on the way out

	ch <- "O"        // emit O
	h2o.b.Await(ctx) // wait for the barrier to open
}
package main

import (
	"errors"
	"fmt"
	"time"

	"golang.org/x/sync/errgroup"
)

// main demonstrates errgroup.Group: three subtasks run concurrently, the
// second one fails, and Wait returns that error after all three are done.
func main() {
	var g errgroup.Group

	// task builds a subtask that sleeps for delay, prints label and then
	// returns result.
	task := func(delay time.Duration, label string, result error) func() error {
		return func() error {
			time.Sleep(delay)
			fmt.Println(label)
			return result
		}
	}

	// First subtask: succeeds.
	g.Go(task(5*time.Second, "exec #1", nil))
	// Second subtask: fails.
	g.Go(task(10*time.Second, "exec #2", errors.New("failed to exec #2")))
	// Third subtask: succeeds.
	g.Go(task(15*time.Second, "exec #3", nil))

	// Wait for all three; Wait yields the first non-nil error, if any.
	err := g.Wait()
	if err != nil {
		fmt.Println("failed:", err)
		return
	}
	fmt.Println("Successfully exec all")
}