harryのブログ

ロードバイクとか模型とかゲームについて何か書いてあるかもしれません

続: RustでなぜかTLEするコードを調べた

前回: RustでなぜかTLEするコードを調べた - harryのブログ

「なぜか」とは???

  • 定数倍が軽い計算量で数値計算をするコードで TLE する
    • または制限時間 5s ギリギリ
  • 対象の問題は、AtCoder 上の競プロ典型90問「055 - Select 5(★2)
  • 全テストケースに対してAC解は出力できてる

tl;dr

問題とコード

問題文

問題文

 \displaystyle N 個の整数  \displaystyle A1,A2,⋯,ANがあります。 この中から 5 個を選ぶ方法のうち、これら 5 個の整数の積を  \displaystyle P で割ると  \displaystyle Q 余るようなものが何通りあるか求めてください。

制約

  •  \displaystyle 5 \leq N \leq 100
  •  \displaystyle 0 \leq Ai \leq 10^{9}
  •  \displaystyle 0 \leq Q \lt P \leq 10^{9}
  • 入力はすべて整数

コード

use proconio::fastout;
use proconio::input;

#[fastout]
fn main() {
    input! {
        n: usize,
        p: i64,
        q: i64,
        a: [i64; n]
    }

    let mut count = 0;
    for i in 0..n-4 {
        for j in i+1..n-3 {
            for k in j+1..n-2 {
                for l in k+1..n-1 {
                    for m in l+1..n {
                        let mut r = 1;
                        r *= a[i]; r %= p;
                        r *= a[j]; r %= p;
                        r *= a[k]; r %= p;
                        r *= a[l]; r %= p;
                        r *= a[m]; r %= p;
                        if r == q {
                            count += 1;
                        }
                    }
                }
            }
        }
    }

    println!("{}", count);
}

提出 #43894703

原因調査

さて、上記のコードの何が問題なのでしょうか?

今見ると、r の初期値は a[i] で直後の r %= p; 要らなくね?みたいなところはある。あと可読性。

計測してみる

  • 早速 callgrind で計測してみる
  • AtCoder の Rust のバージョンは 1.42.0
    • Rust の docker image を使って version を合わせる
  • 一番遅いテストケースは hand_06.txt なので、この入力で計測
    •  \displaystyle N = 100
    •  \displaystyle P = 536870912
    •  \displaystyle Q = 0
  • 前回同様、ファイルから入力を読み込むようにコードを書き換えておく
    • ハードコーディングした場合、入力値に対して最適化される可能性があるため

実行結果

# cargo build --release
# time ./target/release/typical90_055
30980356

real    0m3.228s
user    0m3.228s
sys     0m0.000s
# valgrind --tool=callgrind --cache-sim=yes --branch-sim=yes ./target/release/typical90_055
==27== Callgrind, a call-graph generating cache profiler
==27== Copyright (C) 2002-2017, and GNU GPL'd, by Josef Weidendorfer et al.
==27== Using Valgrind-3.14.0 and LibVEX; rerun with -h for copyright info
==27== Command: ./typical90_055
==27==
--27-- warning: L3 cache found, using its data for the LL simulation.
==27== For interactive control, run 'callgrind_control -h'.
30980356
==27==
==27== Events    : Ir Dr Dw I1mr D1mr D1mw ILmr DLmr DLmw Bc Bcm Bi Bim
==27== Collected : 3680079263 460338456 31662019 1808 4192 1124 1712 2868 1001 986694589 17990006 2120 839
==27==
==27== I   refs:      3,680,079,263
==27== I1  misses:            1,808
==27== LLi misses:            1,712
==27== I1  miss rate:          0.00%
==27== LLi miss rate:          0.00%
==27==
==27== D   refs:        492,000,475  (460,338,456 rd + 31,662,019 wr)
==27== D1  misses:            5,316  (      4,192 rd +      1,124 wr)
==27== LLd misses:            3,869  (      2,868 rd +      1,001 wr)
==27== D1  miss rate:           0.0% (        0.0%   +        0.0%  )
==27== LLd miss rate:           0.0% (        0.0%   +        0.0%  )
==27==
==27== LL refs:               7,124  (      6,000 rd +      1,124 wr)
==27== LL misses:             5,581  (      4,580 rd +      1,001 wr)
==27== LL miss rate:            0.0% (        0.0%   +        0.0%  )
==27==
==27== Branches:        986,696,709  (986,694,589 cond +      2,120 ind)
==27== Mispredicts:      17,990,845  ( 17,990,006 cond +        839 ind)
==27== Mispred rate:            1.8% (        1.8%     +       39.6%   )

なるほど、わからん。

  • I refs (Instruction cache references) 多い気がしないでもない…?
    • このテストケースは \displaystyle N=100なので、ループ回数は  \displaystyle {}_{100} C_5 = 75,287,520
  • そもそも Branches (分岐) がクッソ多い…?
    • I refs に対してもデカすぎんだろ…
    • ただし、分岐予測ミスは 1.8% でそれほどミスってない
      • 母数がデカいので、絶対数は多い
  • Indirect branches (間接分岐) の予測ミスが 40% …?
    • なるほどね(わかってない)

アセンブリを見る

Compiler Explorerの共有リンク

https://godbolt.org/z/9qov5645W

Rust コード

                        let mut r = 1;
                        r *= a[i];
                        r %= p;
                        r *= a[j];
                        r %= p;
                        r *= a[k];
                        // ...

アセンブリ

.Lfunc_begin3:
    .cfi_startproc
    .cfi_personality 155, DW.ref.rust_eh_personality
    .cfi_lsda 27, .Lexception3
    push rbp
    .cfi_def_cfa_offset 16
    push r15
    .cfi_def_cfa_offset 24
    push r14
    .cfi_def_cfa_offset 32
    push r13
    .cfi_def_cfa_offset 40
    push r12
    .cfi_def_cfa_offset 48
    push rbx
    .cfi_def_cfa_offset 56
    sub  rsp, 312
    .cfi_def_cfa_offset 368
    .cfi_offset rbx, -56
    .cfi_offset r12, -48
    .cfi_offset r13, -40
    .cfi_offset r14, -32
    .cfi_offset r15, -24
    .cfi_offset rbp, -16

; ...

.Ltmp65:
    lea  rdi, [rsp + 32]
    mov  rsi, r13
    mov  rdx, rbp
    call qword ptr [rip + _ZN4core3num52_$LT$impl$u20$core..str..FromStr$u20$for$u20$i64$GT$8from_str17hed3d610c545d1817E@GOTPCREL]
.Ltmp66:
    cmp  byte ptr [rsp + 32], 1
    mov  rbx, qword ptr [rsp + 24]
    je   .LBB19_251
    mov  r13, qword ptr [rsp + 40]
    mov  rax, qword ptr [rsp]
    mov  edi, eax
    test dil, dil
    je   .LBB19_79
    jmp  .LBB19_107
    .p2align  4, 0x90

; ...

.LBB19_200:
    mov  qword ptr [rsp + 32], r15
    movaps   xmm0, xmmword ptr [rsp + 224]
    movups   xmmword ptr [rsp + 40], xmm0
    mov  dword ptr [rsp + 224], 0
    mov  rax, qword ptr [rsp + 168]
    add  rax, -4
    mov  qword ptr [rsp + 256], rax
    je   .LBB19_230
    movabs   rdi, -9223372036854775808

; ...

.LBB19_213:
    test r13, r13
    je   .LBB19_241
    mov  rax, qword ptr [r15 + 8*rsi]
    cmp  r13, -1
    jne  .LBB19_216
    cmp  rax, rdi
    je   .LBB19_243
.LBB19_216:
    cmp  rcx, r8
    jbe  .LBB19_234
    cqo
    idiv r13
    mov  rax, rdx
    imul rax, qword ptr [r15 + 8*r8]
    cmp  r13, -1
    jne  .LBB19_219
    cmp  rax, rdi
    je   .LBB19_240
.LBB19_219:
    cmp  rcx, rbp
    jbe  .LBB19_235
    cqo
    idiv r13
    mov  rax, rdx
    imul rax, qword ptr [r15 + 8*rbp]
    cmp  r13, -1
    jne  .LBB19_222
    cmp  rax, rdi
    je   .LBB19_242
.LBB19_222:
    ; ...

なるほど、なんもわからん。

ただし、Compiler Explorer は元のコードとコンパイル後コードで対応する行に同じ色を付けてくれます。また、-C debuginfo=1 オプションをつけると、各行に対して .loc が付与され、元の Rust コードとの対応関係が分かります*2

とりあえず歯を食いしばりながら読もうと努力してみると、以下のような事がわかります(?)

  1. 変数 r に対する 1 の代入と r *= a[i]; が最適化されている
    • let mut r = a[i]; のようなコードになった
  2. 64bit (qword ptr) の符号付き乗算 (imul) と 剰余計算 (cqo & idiv) が行われている
  3. 0除算チェックが1回行われている
    • r13core::num::<impl core::str::FromStr for i64>::from_str で読み込まれてる p (?)
      • idiv の引数に使われている
    • test r13, r13
      • AND 演算を行うので、r13 が 0 の時だけ ZF (ゼロフラグ)が立つ
  4. 剰余計算の前に毎回以下のような処理と jump (jne, je)するコードがある
    • cmp r13, -1
    • cmp rax, rdi

4点目ですが、MIR や r13rdi に格納されている値などから、このコードが除算の overflow チェックコードであることがわかります。

MIR:

    bb71: {
        _131 = Eq(_129, const -1i64);    // bb71[0]: scope 30 at src/main.rs:25:25: 25:31
                                         // ty::Const
                                         // + ty: i64
                                         // + val: Value(Scalar(0xffffffffffffffff))
                                         // mir::Constant
                                         // + span: src/main.rs:25:25: 25:31
                                         // + literal: Const { ty: i64, val: Value(Scalar(0xffffffffffffffff)) }
        _132 = Eq(_124, const std::i64::MIN); // bb71[1]: scope 30 at src/main.rs:25:25: 25:31
                                         // ty::Const
                                         // + ty: i64
                                         // + val: Value(Scalar(0x8000000000000000))
                                         // mir::Constant
                                         // + span: src/main.rs:25:25: 25:31
                                         // + literal: Const { ty: i64, val: Value(Scalar(0x8000000000000000)) }
        _133 = BitAnd(move _131, move _132); // bb71[2]: scope 30 at src/main.rs:25:25: 25:31
        assert(!move _133, "attempt to calculate the remainder with overflow") -> [success: bb72, unwind: bb39]; // bb71[3]: scope 30 at src/main.rs:25:25: 25:31
    }

このチェックは -C overflow-checks オプションが無効でも ( release ビルドでも) 行われます。

https://doc.rust-lang.org/reference/expressions/operator-expr.html#overflow

  • Using / or %, where the left-hand argument is the smallest integer of a signed integer type and the right-hand argument is -1. These checks occur even when -C overflow-checks is disabled, for legacy reasons.

なるほどつまり、Rust コードの %= p; の前に毎回ピョンピョン jump で分岐していたわけですね(?)

ちなみに、乗算の処理を1行コメントアウトすると、剰余の計算 (cqo & idiv) は1回しか行われませんが、overflow チェックは2回行われ、結果的に2回 jump するようです(えぇ…

Rust コード

                        let mut r = 1;
                        r *= a[i];
                        r %= p;
                        // r *= a[j];
                        r %= p;
                        r *= a[k];

アセンブリ

.LBB19_213:
    test r13, r13
    je   .LBB19_238
    mov  rax, qword ptr [r15 + 8*rsi]
    cmp  r13, -1
    jne  .LBB19_216
    cmp  rax, rdi
    je   .LBB19_240
.LBB19_216:
    cqo
    idiv r13
    mov  rax, rdx
    cmp  r13, -1
    jne  .LBB19_218
    cmp  rax, rdi
    je   .LBB19_241
.LBB19_218:
    cmp  rcx, rbx
    jbe  .LBB19_233
    imul rax, qword ptr [r15 + 8*rbx]
    cmp  r13, -1
    jne  .LBB19_221
    cmp  rax, rdi
    je   .LBB19_242
.LBB19_221:

想定解のコードと比較してみる

この問題の想定解は以下のようになっています。(C++のコードをRustで記述)

if a[i] * a[j] % p * a[k] % p * a[l] % p * a[m] % p == q {
    count += 1;
}

実際に計測してみましょう。

Code Explorerの共有リンク

https://godbolt.org/z/jhdj59rz4

実行結果

# cargo rustc --release -- -C llvm-args=--x86-asm-syntax=intel --emit asm --emit llvm-ir --emit mir
# time ./target/release/typical90_055
30980356

real    0m2.464s
user    0m2.464s
sys     0m0.000s
# valgrind --tool=callgrind --cache-sim=yes --branch-sim=yes ./target/release/typical90_055
==27== Callgrind, a call-graph generating cache profiler
==27== Copyright (C) 2002-2017, and GNU GPL'd, by Josef Weidendorfer et al.
==27== Using Valgrind-3.14.0 and LibVEX; rerun with -h for copyright info
==27== Command: ./target/release/typical90_055
==27==
--27-- warning: L3 cache found, using its data for the LL simulation.
==27== For interactive control, run 'callgrind_control -h'.
30980356
==27==
==27== Events    : Ir Dr Dw I1mr D1mr D1mw ILmr DLmr DLmw Bc Bcm Bi Bim
==27== Collected : 3303641251 460338373 31662022 1801 4184 1130 1703 2865 1000 911406936 17989987 2120 838
==27==
==27== I   refs:      3,303,641,251
==27== I1  misses:            1,801
==27== LLi misses:            1,703
==27== I1  miss rate:          0.00%
==27== LLi miss rate:          0.00%
==27==
==27== D   refs:        492,000,395  (460,338,373 rd + 31,662,022 wr)
==27== D1  misses:            5,314  (      4,184 rd +      1,130 wr)
==27== LLd misses:            3,865  (      2,865 rd +      1,000 wr)
==27== D1  miss rate:           0.0% (        0.0%   +        0.0%  )
==27== LLd miss rate:           0.0% (        0.0%   +        0.0%  )
==27==
==27== LL refs:               7,115  (      5,985 rd +      1,130 wr)
==27== LL misses:             5,568  (      4,568 rd +      1,000 wr)
==27== LL miss rate:            0.0% (        0.0%   +        0.0%  )
==27==
==27== Branches:        911,409,056  (911,406,936 cond +      2,120 ind)
==27== Mispredicts:      17,990,825  ( 17,989,987 cond +        838 ind)
==27== Mispred rate:            2.0% (        2.0%     +       39.5%   )

…まぁ、0.8s ほど早くはなったが…?アセンブリはどうでしょうか?

アセンブリ

.LBB19_213:
    cmp  rcx, r8
    jbe  .LBB19_232
    test r13, r13
    je   .LBB19_238
    mov  rax, qword ptr [r15 + 8*r8]
    imul rax, qword ptr [r15 + 8*rsi]
    cmp  r13, -1
    jne  .LBB19_217
    cmp  rax, rdi
    je   .LBB19_240
.LBB19_217:
    cmp  rcx, rbp
    jbe  .LBB19_233
    cqo
    idiv r13
    mov  rax, rdx
    imul rax, qword ptr [r15 + 8*rbp]
    cmp  r13, -1
    jne  .LBB19_220
    cmp  rax, rdi
    je   .LBB19_241
.LBB19_220:
    ; ...

特に代わり映えしない出力ですが、元のコードでやっていた最初の余計な %= p; の jump がなくなり、I refs や Branches が減って処理が速くなったようです。

では元のコードも同じように書き換えればいいのでは???と言うことで試してみると、確かに同じような実行時間で終わるようになりました。

Rust コード

                        let mut r = a[i];
                        r *= a[j];
                        r %= p;
                        r *= a[k];
                        r %= p;
                        // ...

実行結果

# time ./target/release/typical90_055
30980356

real    0m2.457s
user    0m2.456s
sys     0m0.000s

より速くする

とりあえずホットスポットで何度も剰余計算するのが良くないようです。ここで実行時間が 1s を切っている中で一番提出が早いコードを見てみましょう。

提出 #23104825

Rust コード

639 ms

    let mut ans = 0usize;
    for i in 4..n {
        let x = a[i];
        for j in 3..i {
            let x = x * a[j] % p;
            for k in 2..j {
                let x = x * a[k] % p;
                for l in 1..k {
                    let x = x * a[l] % p;
                    for m in 0..l {
                        let x = x * a[m] % p;
                        ans += (x == q) as usize;
                    }
                }
            }
        }
    }

なるほど、確かにこれなら剰余計算による jump の回数を減らせそうです。似たようなコードに書き換えて計測してみると以下のような結果になります。

Code Explorerの共有リンク

https://godbolt.org/z/nM1qzx1d1

実行結果

#  time ./target/release/typical90_055
30980356

real    0m0.603s
user    0m0.593s
sys     0m0.010s
# valgrind --tool=callgrind --cache-sim=yes --branch-sim=yes ./target/release/typical90_055
==25== Callgrind, a call-graph generating cache profiler
==25== Copyright (C) 2002-2017, and GNU GPL'd, by Josef Weidendorfer et al.
==25== Using Valgrind-3.14.0 and LibVEX; rerun with -h for copyright info
==25== Command: ./target/release/typical90_055
==25==
--25-- warning: L3 cache found, using its data for the LL simulation.
==25== For interactive control, run 'callgrind_control -h'.
30980356
==25==
==25== Events    : Ir Dr Dw I1mr D1mr D1mw ILmr DLmr DLmw Bc Bcm Bi Bim
==25== Collected : 1162983272 91910173 31671520 1799 4185 1130 1704 2866 1001 316953793 20069087 2120 840
==25==
==25== I   refs:      1,162,983,272
==25== I1  misses:            1,799
==25== LLi misses:            1,704
==25== I1  miss rate:          0.00%
==25== LLi miss rate:          0.00%
==25==
==25== D   refs:        123,581,693  ( 91,910,173 rd + 31,671,520 wr)
==25== D1  misses:            5,315  (      4,185 rd +      1,130 wr)
==25== LLd misses:            3,867  (      2,866 rd +      1,001 wr)
==25== D1  miss rate:           0.0% (        0.0%   +        0.0%  )
==25== LLd miss rate:           0.0% (        0.0%   +        0.0%  )
==25==
==25== LL refs:               7,114  (      5,984 rd +      1,130 wr)
==25== LL misses:             5,571  (      4,570 rd +      1,001 wr)
==25== LL miss rate:            0.0% (        0.0%   +        0.0%  )
==25==
==25== Branches:        316,955,913  (316,953,793 cond +      2,120 ind)
==25== Mispredicts:      20,069,927  ( 20,069,087 cond +        840 ind)
==25== Mispred rate:            6.3% (        6.3%     +       39.6%   )

I refs と Branches が 1/3 程度になり高速化されました。やったね。

本当にこれで終わり…?

さて…

  • AtCoderでは言語アップデートが水面下で進んでおり、言語テストが行えるコンテストページがある*3
  • そのコンテストでは Rust 1.70 が使える

ので、最初の「遅いコード」をコードテストで動かしてみると、なんと 400 ms ほどで処理が完了します。

実は、Rust 1.42 では opt-level = 3 で SEGV が発生する問題があり、release profile で opt-level = 2 になっていました。

では、Rust 1.43 で release ビルドしたら速くなるのでは?という疑問が出てくるので、実際に試してみると、速くなりません(!?)

というわけで(?)、Rust のバージョンを二分探索してどのバージョンから速くなるのか調べたところ… 1.59.0 から明らかに速くなりました。

計測結果 (1.59.0)

# time ./target/release/typical90_055
30980356

real    0m1.028s
user    0m1.028s
sys     0m0.000s
# valgrind --tool=callgrind --cache-sim=yes --branch-sim=yes ./target/release/typical90_055
==28== Callgrind, a call-graph generating cache profiler
==28== Copyright (C) 2002-2017, and GNU GPL'd, by Josef Weidendorfer et al.
==28== Using Valgrind-3.16.1 and LibVEX; rerun with -h for copyright info
==28== Command: ./target/release/typical90_055
==28==
--28-- warning: L3 cache found, using its data for the LL simulation.
==28== For interactive control, run 'callgrind_control -h'.
30980356
==28==
==28== Events    : Ir Dr Dw I1mr D1mr D1mw ILmr DLmr DLmw Bc Bcm Bi Bim
==28== Collected : 2168881667 260208883 371061 1690 4013 1127 1606 2842 1008 343119591 4343526 1648 438
==28==
==28== I   refs:      2,168,881,667
==28== I1  misses:            1,690
==28== LLi misses:            1,606
==28== I1  miss rate:          0.00%
==28== LLi miss rate:          0.00%
==28==
==28== D   refs:        260,579,944  (260,208,883 rd + 371,061 wr)
==28== D1  misses:            5,140  (      4,013 rd +   1,127 wr)
==28== LLd misses:            3,850  (      2,842 rd +   1,008 wr)
==28== D1  miss rate:           0.0% (        0.0%   +     0.3%  )
==28== LLd miss rate:           0.0% (        0.0%   +     0.3%  )
==28==
==28== LL refs:               6,830  (      5,703 rd +   1,127 wr)
==28== LL misses:             5,456  (      4,448 rd +   1,008 wr)
==28== LL miss rate:            0.0% (        0.0%   +     0.3%  )
==28==
==28== Branches:        343,121,239  (343,119,591 cond +   1,648 ind)
==28== Mispredicts:       4,343,964  (  4,343,526 cond +     438 ind)
==28== Mispred rate:            1.3% (        1.3%     +    26.6%   )

version 毎の結果

Rust version time
1.70.0 1.034s
1.59.0 1.028s
1.58.0 2.744s
1.42.0 3.228s

例えば LLVM IR の差分として、一番最後の剰余計算の処理が、1.58.0 では別のブロックに分かれてますが、1.59.0 では1つ前の剰余&乗算処理と同じブロックにまとまっています。ただし、これは出力されるアセンブリにほぼ影響を与えていないような気がします。

アセンブリの方はと言うと、1.59.0では p の 0 チェック (0除算回避チェック) が一番内側のループ開始時から、それよりも外側で行うように最適化するなど頑張っているようです。今回のコードでは main 関数に処理をすべてベタ書きしていますが、5重ループの処理を別の関数に分けると、関数に入った直後 (ループ開始前) に一度だけチェックするように更に最適化されました。*4

他にも色々最適化されてる可能性はありますが…なんもわからん。というかそもそも具体的になんの変更が影響してるのかリリースノートから分からなかった。

同様の事例

ABC323 E - Playlist

各曲が選択される確率は  \frac{1}{N} (確率 mod) 固定なので、事前に計算した値を使いまわせます。ですが各時刻における確率計算で毎回この剰余演算を行うと実行時間が極端に悪化します。

まとめ

  • Rust では除算&剰余計算で overflow チェックコードによる jump が挿入される
    • 他の演算に比べ、その分コストが高い
  • Rust 1.58 以前では剰余計算に対する最適化が弱い (?)
    • そのため、ホットスポットで何度も剰余計算するとパフォーマンスが低下する
      • ことがある?
    • 心当たりがあるコードでパフォーマンスが出ていない場合、計測して確認する価値がありそう
  • なるべく最新版の Rust を使おう
    • 前回 Rust の局所最適化が難しいみたいな感想書いたが、Rust 1.42 の最適化が弱いことも一因だったっぽい
    • 同じアセンブルを出力するように見えるコードでも、実際には出力が異なることがある
  • Rust なにもわからない
  • コンパイラはお前より頭が良いので、コンパイラの最適化を信じろ(はい)

参考リンク


*1:今回は調査中に何が必要となるか分からないので全部出力する

*2:Compiler Explorerではこのオプションが指定されている

*3:2023/07/24 時点

*4:https://godbolt.org/z/3eEjEecqh