cloverrose's blog

Python, Machine learning, Emacs, CI/CD, Webアプリなど

Go言語による並行処理 4.7 ファンアウト、ファンイン

www.oreilly.co.jp

Go言語による並行処理の本は、仕事でよく使う部分とあんまり使わない部分がある。 よく使う部分をコピペで使いやすくしておく。

本のファンアウト、ファンインの例ではdoneチャネルを使っているが、仕事ではcontextを使うので、そこはcontextに置き換えてある。

package main

import (
    "context"
    "errors"
    "sync"
    "time"
)

// Value wraps a single integer. It is the element type produced by the
// generator, grouped into batches, and returned by the batch-processing
// function.
type Value struct {
    value int
}

// generator is a stateful source of Values; see Gen for how the fields are used.
type generator struct {
    // cnt is the number of values produced so far; it is also the payload of
    // the most recently produced Value.
    cnt          int
    // max is the count at which Gen stops producing values and reports an error.
    max          int
    // waitDuration is how long Gen sleeps before producing each value.
    waitDuration time.Duration
}

// Gen returns the next value of the sequence 1, 2, ..., g.max.
//
// It sleeps for g.waitDuration before producing each value. Once g.max values
// have been produced it returns a nil Value and a non-nil error, which callers
// can treat as end of input. A generator whose max is zero or negative is
// exhausted from the start.
//
// Gen mutates g.cnt without synchronization, so a generator must not be
// shared between goroutines that call Gen concurrently.
func (g *generator) Gen() (*Value, error) {
    // Compare with >= rather than == so that a negative max (or a cnt that is
    // already past max) ends the sequence instead of generating forever.
    if g.cnt >= g.max {
        return nil, errors.New("done")
    }
    time.Sleep(g.waitDuration)
    g.cnt++
    return &Value{value: g.cnt}, nil
}

// batchMaker starts a goroutine that pulls values from g and groups them into
// slices of batchSize elements, delivering each slice on the returned channel.
//
// Any error from g.Gen is treated as end of input: a partially filled batch is
// delivered, and then the channel is closed. An empty trailing batch is never
// sent, so every received batch holds at least one value. A batchSize below 1
// yields single-element batches.
//
// The goroutine stops and closes the channel once ctx is cancelled. The
// cancellation is checked before every call to g.Gen and while waiting for a
// receiver, so at most one in-flight Gen call delays shutdown.
func batchMaker(ctx context.Context, g *generator, batchSize int) <-chan []*Value {
    batchCh := make(chan []*Value)
    go func() {
        defer close(batchCh)

        // send delivers one batch; it reports false if ctx was cancelled
        // before a receiver took it.
        send := func(batch []*Value) bool {
            select {
            case <-ctx.Done():
                return false
            case batchCh <- batch:
                return true
            }
        }
        // newBatch pre-sizes the slice, guarding against a non-positive
        // capacity (a negative one would make `make` panic).
        newBatch := func() []*Value {
            if batchSize < 1 {
                return nil
            }
            return make([]*Value, 0, batchSize)
        }

        batch := newBatch()
        for {
            // Do not keep driving a possibly slow generator after cancellation.
            if ctx.Err() != nil {
                return
            }
            value, err := g.Gen()
            if err != nil {
                // End of input: flush the remainder, but never an empty batch.
                if len(batch) > 0 {
                    send(batch)
                }
                return
            }
            batch = append(batch, value)
            if len(batch) < batchSize {
                continue
            }
            if !send(batch) {
                return
            }
            // The receiver now owns the sent slice; start a fresh one.
            batch = newBatch()
        }
    }()
    return batchCh
}

// batchConsumer starts a goroutine that receives batches from batchCh, applies
// fn to each one, and delivers fn's result on the returned channel.
//
// The goroutine ends, closing the returned channel, when any of these happens:
//   - batchCh is closed,
//   - ctx is cancelled (checked both while waiting for a batch and while
//     waiting for a receiver of the result),
//   - fn returns an error. The error itself is discarded; the consumer simply
//     stops processing further batches.
//
// Several consumers may share one batchCh to fan work out.
func batchConsumer(ctx context.Context, batchCh <-chan []*Value, fn func(context.Context, []*Value) (*Value, error)) <-chan *Value {
    resultCh := make(chan *Value)
    go func() {
        defer close(resultCh)
        for {
            // Receive with select instead of ranging over batchCh so that a
            // producer which never closes its channel cannot pin this
            // goroutine after cancellation.
            var batch []*Value
            select {
            case <-ctx.Done():
                return
            case b, ok := <-batchCh:
                if !ok {
                    return
                }
                batch = b
            }

            result, err := fn(ctx, batch)
            if err != nil {
                return // or ignore error
            }

            select {
            case <-ctx.Done():
                return
            case resultCh <- result:
            }
        }
    }()
    return resultCh
}

// fanIn merges any number of input channels into the single returned channel.
//
// One goroutine per input forwards values as they arrive, so the relative
// order of values coming from different inputs is unspecified. The returned
// channel is closed once every forwarding goroutine has finished, which
// happens when its input is closed or ctx is cancelled. With no inputs the
// returned channel is closed right away.
func fanIn(ctx context.Context, channels ...<-chan *Value) <-chan *Value {
    var wg sync.WaitGroup
    merged := make(chan *Value)

    multiplex := func(c <-chan *Value) {
        defer wg.Done()
        for {
            // Receive with select instead of ranging over c so that an input
            // which is never closed cannot keep this goroutine (and therefore
            // merged) alive after cancellation.
            var v *Value
            select {
            case <-ctx.Done():
                return
            case got, ok := <-c:
                if !ok {
                    return
                }
                v = got
            }

            select {
            case <-ctx.Done():
                return
            case merged <- v:
            }
        }
    }

    // Add before starting the goroutines so Wait cannot observe a zero
    // counter while forwarders are still being launched.
    wg.Add(len(channels))
    for _, c := range channels {
        go multiplex(c)
    }

    // Only close merged after every sender is done; closing earlier would
    // make a late send panic.
    go func() {
        wg.Wait()
        close(merged)
    }()
    return merged
}

// makeCalc returns a batch-processing function that sums the values of a
// batch, waiting waitDuration before adding each element to simulate
// per-element work.
//
// The returned function gives up with ctx.Err() and a nil Value if ctx is
// cancelled before all elements have been processed. An empty batch yields a
// Value of 0 and no error. A waitDuration of zero or less means no waiting.
func makeCalc(waitDuration time.Duration) func(context.Context, []*Value) (*Value, error) {
    return func(ctx context.Context, batch []*Value) (*Value, error) {
        // One timer is reused for the whole batch instead of allocating a new
        // one per element with time.After; none is needed when there is
        // nothing to wait for.
        var timer *time.Timer
        if waitDuration > 0 && len(batch) > 0 {
            timer = time.NewTimer(waitDuration)
            defer timer.Stop()
        }

        sum := 0
        for i, v := range batch {
            if timer == nil {
                // No simulated work: still honor cancellation, deterministically.
                if err := ctx.Err(); err != nil {
                    return nil, err
                }
                sum += v.value
                continue
            }
            if i > 0 {
                // Safe without draining: the previous iteration already
                // received from timer.C, otherwise we would have returned.
                timer.Reset(waitDuration)
            }
            select {
            case <-ctx.Done():
                return nil, ctx.Err()
            case <-timer.C:
                sum += v.value
            }
        }
        return &Value{value: sum}, nil
    }
}

// Demo wires the whole pipeline together and runs it to completion:
//
//	batchMaker (one producer) -> numConsumers x batchConsumer -> fanIn
//
// Values from g are grouped into batches of batchSize, each batch is handed
// to one of numConsumers concurrent consumers running fn, and the merged
// results are drained and discarded. Demo returns once the merged result
// channel is closed.
//
// With numConsumers of zero or less no consumer is started, so nothing is
// processed and Demo returns immediately.
func Demo(ctx context.Context, g *generator, batchSize int, numConsumers int, fn func(context.Context, []*Value) (*Value, error)) {
    // The pipeline can finish while the producer is still blocked on a send
    // (every consumer stopped on an fn error, or there are no consumers at
    // all). Cancelling on return signals any such goroutine to stop instead
    // of leaking it when the caller's ctx is never cancelled.
    ctx, cancel := context.WithCancel(ctx)
    defer cancel()

    // A negative length would make `make` panic.
    if numConsumers < 0 {
        numConsumers = 0
    }

    batchCh := batchMaker(ctx, g, batchSize)

    // Fan out: every consumer reads from the same batch channel.
    resultChs := make([]<-chan *Value, numConsumers)
    for i := range resultChs {
        resultChs[i] = batchConsumer(ctx, batchCh, fn)
    }

    // Fan in and drain; the results themselves are not needed here.
    for range fanIn(ctx, resultChs...) {
    }
}

ベンチマークで、コンシューマを増やすと本当に速くなっているかも確認する。

package main

import (
    "context"
    "testing"
    "time"
)

func BenchmarkDemo(b *testing.B) {
    type args struct {
        ctx          context.Context
        g            *generator
        batchSize    int
        numConsumers int
        fn           func(context.Context, []*Value) (*Value, error)
    }
    tests := []struct {
        name    string
        args    args
        wantErr bool
    }{
        {
            name: "no sleep",
            args: args{
                ctx:          context.Background(),
                g:            &generator{max: 100 * 80},
                batchSize:    100,
                numConsumers: 1,
                fn:           makeCalc(0),
            },
            wantErr: false,
        },
        {
            name: "no sleep x7",
            args: args{
                ctx:          context.Background(),
                g:            &generator{max: 100 * 80},
                batchSize:    100,
                numConsumers: 7,
                fn:           makeCalc(0),
            },
            wantErr: false,
        },
        {
            name: "maker and consumer is slow",
            args: args{
                ctx:          context.Background(),
                g:            &generator{max: 100 * 80, waitDuration: 1 * time.Millisecond},
                batchSize:    100,
                numConsumers: 1,
                fn:           makeCalc(1 * time.Millisecond),
            },
            wantErr: false,
        },
        {
            name: "maker and consumer is slow x7",
            args: args{
                ctx:          context.Background(),
                g:            &generator{max: 100 * 80, waitDuration: 1 * time.Millisecond},
                batchSize:    100,
                numConsumers: 7,
                fn:           makeCalc(1 * time.Millisecond),
            },
            wantErr: false,
        },
        {
            name: "consumer is slower",
            args: args{
                ctx:          context.Background(),
                g:            &generator{max: 100 * 80},
                batchSize:    100,
                numConsumers: 1,
                fn:           makeCalc(1 * time.Millisecond),
            },
            wantErr: false,
        },
        {
            name: "consumer is slower x7",
            args: args{
                ctx:          context.Background(),
                g:            &generator{max: 100 * 80},
                batchSize:    100,
                numConsumers: 7,
                fn:           makeCalc(1 * time.Millisecond),
            },
            wantErr: false,
        },
    }
    for _, tt := range tests {
        b.Run(tt.name, func(b *testing.B) {
            Demo(tt.args.ctx, tt.args.g, tt.args.batchSize, tt.args.numConsumers, tt.args.fn)
        })
    }
}