harryのブログ

ロードバイクとか模型とかゲームについて何か書いてあるかもしれません

バカ(俺)向け二分探索コードの書き方

なにこれ

  • たまに競プロで二分探索を手書きする必要があるが、その度に境界の扱いが分からなくなる(バカ)。
  • のでまとめた。
  • アルゴリズム実技検定 公式テキスト[上級]~[エキスパート]編 には二分探索の章が20ページほど割かれており、この記事以上に詳細が書かれているかも?
    • まだ買ってない(欲しい)。
  • この記事のコードはすべてRustです。
    • Rustは 0-based indexing です。

完全一致する値を探す

  • ソート済みのコレクション(Vec)に値が存在するか探すことを目的とする。
    • 特定の連続範囲に値が含まれているかは RangeRangeInclusive で容易に判別可能というか、自明なため…
  • 正確には完全一致する値のindexを探す。
  • Rustの場合、sliceに以下のメソッドが定義されているので、基本的にはこれを使えばよい。
    • [T].binary_search()
    • [T].binary_search_by()
    • [T].binary_search_by_key()
    • これらの実動作はrustdocで確認するのが早い

コード例?

他言語でよく書かれているコードをRustでベタ移植すると、以下のような感じ。

use std::cmp::Ordering;

pub fn search<T: Ord>(vec: &Vec<T>, x: &T) -> Option<usize> {
    let mut left = 0;
    let mut right = vec.len();

    while left <= right {
        let mid = (left + right) / 2;
        match vec[mid].cmp(x) {
            Ordering::Equal   => { return Some(mid); },
            Ordering::Less    => { left  = mid + 1; }, // x は mid より 大きい 範囲に存在する (はず)
            Ordering::Greater => { right = mid - 1; }, // x は mid より 小さい 範囲に存在する (はず)
        }
    }

    // left > right
    None
}
  • この実装では一致する値が複数あった場合、どれか1つの index が返ります。
  • 結果はOption型で返しています。
    • C++などでは見つからなかったら -1 を返しますが、Rust の index は符号なし整数型(非負整数)のため。
    • もちろん binary_search() のようにResult型で返してもよいです。

探索が停止する直前は以下のような状態になります。

0 left right (= left + 1) len

この時 mid は (2 * left + 1) / 2 の整数型演算により left となります。この mid が条件を満たさない場合、以下のどちらかで探索が停止します。

  • mid が x に対して小さい場合
    • left = mid + 1left == right になり、元々 right だった値でもう一度チェックされる
      • それで見つからなければ、 left > right になり停止
  • mid が x に対して大きい場合
    • right = mid - 1 で、left > right になり停止

コード例の問題点

前記のコード例には少なくとも以下の問題があります。

  1. Vecのサイズがusize型の上限値*1だった場合、(left + right) / 2 が overflow する可能性がある。
    • 実際には OOM(out of memory) の方で死ぬと思う…。
  2. 空のVecを渡した場合に vec[0] が参照され、index out of bounds の実行時エラーとなる。
  3. Vec内の最小値よりも小さい値を指定した場合、'attempt to subtract with overflow' の実行時エラーとなる。
    • right = mid - 1; でusize型(符号なし整数型)に対して 0 - 1 が実行されてしまう。

最初の問題点は二分探索のコードでありがちな (有名な) バグで、以下のようにすることで回避できます。

    // ...
    while left <= right {
        let len = (right - left) / 2;
        let mid = left + len;
    // ...

が、おそらく競プロにおいて問題になることはなく、この記事の本旨でもないので、今回は修正しない方針で統一します。

2つ目以降の問題を直す方法を少し考えてみます (3つ目の問題はRust特有ではありますが)。

とりあえず left と right が等しくなった時にループを抜けるようにします。既に見てきたように (left + right) / 2 は left と right が隣り合った時に left が返るので、left の更新時だけ mid + 1 にして確実にループが停止するようにします。

実はこれだけで修正が終わっています。

use std::cmp::Ordering;

pub fn search<T: Ord>(vec: &Vec<T>, x: &T) -> Option<usize> {
    let mut left = 0;
    let mut right = vec.len();

    while left < right {
        let mid = (left + right) / 2;
        match vec[mid].cmp(x) {
            Ordering::Equal   => { return Some(mid); },
            Ordering::Less    => { left  = mid + 1; }, // x は mid より 大きい 範囲に存在する (はず)
            Ordering::Greater => { right = mid; },     // x は mid より 小さい 範囲に存在する (はず)
        }
    }

    // left == right
    None
}

一応、探索が停止する直前の状態を見てみます。この状態では mid == left です。

0 left right (= left + 1) len
  • mid が x に対して小さい場合
    • left = mid + 1left == right になり停止 ( None が返る)
  • mid が x に対して大きい場合
    • right = mid で、left == right になり停止 ( None が返る)

私はバカなので「もし right が答えだったらどうするんだよ」と考えてしまいますが、right の初期値にはしれっと答えの範囲外である vec.len() が入っているので、実は問題になりません。

  • right が vec.len() 以外の値だった場合
    • そもそも right の値は一致していないので、チェックは不要。
  • right が vec.len() だった場合
    • left が vec.len() - 1 なので、この値が一致しなかったら None でよい。

条件を満たす最大値を探す

さて少し話を変えて、1つの値から真偽値(bool)を返す関数を定義した時、それを満たす 最大の値 を範囲内から二分探索することを考えます。二分探索のシグネチャは以下の通りです。

pub fn search<F>(min: i32, max: i32, f: F) -> i32
where
    F: Fn(&i32) -> bool

こちらは数値の範囲内で二分探索を行うので、Vec(メモリ)に収まらない範囲を高速に探索する必要がある場合、特に競プロで答えを二分探索する際に使います*2

最大値を探す前提として、定義する関数は指定範囲内(min~max)で以下のように規則正しく真偽値を返すものとします。

v min x-1 x x + 1 x + 2 max
f(v) true true true false false false
探したい
最大値

f(min) == false の場合、この範囲内において求めるべき値は存在しません。

値を探す二分探索

二分探索で left == right となるような値を1つ探すことを考えます。

コードは以下のような感じ。

pub fn search<F>(min: i32, max: i32, f: F) -> i32
where
    F: Fn(&i32) -> bool
{
    let mut left = min;
    let mut right = max;

    while left < right {
        let mid = (left + right + 1) / 2;
        if f(&mid) {
            left = mid;      // f() を 満たす値 の最大値は mid 以上
        } else {
            right = mid - 1; // f() を 満たす値 は mid より 小さい 範囲にある (はず)
        }
    }

    // left == right
    left
}

探索が停止する直前は以下のような状態になります。

min left (= right - 1) right max
条件を満たす 条件を満たさない

この時 mid の計算式が (left + right + 1) / 2 であることに注意すると、mid = right (= 2 * right / 2 ) になり、f(&mid) の戻り値がどちらだったとしても left == right で探索が停止します。

  • max - 1 までがすべて条件を満たす場合でも、最後に max (right) に対するチェックが行われます。
    • f(max) が true: max が返る
    • f(max) が false: max - 1 が返る
  • ただし、left == min のまま探索が終了した場合は、以下のどちらかです。
    • min だけが条件を満たす
    • 範囲内の値がすべて条件を満たさない

そのため、厳密には while ループ後に left (min) が条件を満たすかどうかチェックした方がよいです。が、min が与えられた制約の中で取りえる最小値であることが自明なら特に問題はなく、競プロでは気にする必要がないことがほとんどだと思います(死亡フラグ)。まぁ、max だった場合もこの範囲内での最大値がそれ、という話なので。

境界を探す二分探索

「値を探す二分探索」とほぼ同じですが、ここでは left と right が隣り合った時、つまり left と right の差が1の時に探索を停止することを考えます。

コードは以下のような感じ。

pub fn search<F>(min: i32, max: i32, f: F) -> i32
where
    F: Fn(&i32) -> bool
{
    let mut left = min;
    let mut right = max;

    while right - left > 1 {
        let mid = (left + right) / 2;
        if f(&mid) {
            left = mid;  // f() を 満たす値 の最大値は mid 以上
        } else {
            right = mid; // f() を 満たす値 は mid より 小さい 範囲にある (はず)
        }
    }

    // left + 1 == right
    left + i32::from(f(&right))
}

探索が停止する直前、left と right は隣り合ってないので、以下の状態です。

min left mid right max
満たす ? 満たさない

f(&mid)の結果によらず、次の処理で left と right が隣り合う状態になり探索が停止します。

ただし right == max のまま while ループを抜けた場合、right が求めるべき答えの可能性があるのでチェックが必要です。Rust には false/true から 0/1 を返す i32::from(small: bool)usize::from(small: bool) があります。これを使って right が条件を満たす場合には left に1を加算して right が返るようにしています。

while の終了条件や left / right の更新方法などが前述の「値を探す二分探索」のコードと異なりますが、同じような結果が得られました。個人的には、mid の算出方法や left/right の更新方法はこちらの方が少し簡素な印象があります。

条件を満たす最小値を探す

では次に、1つの値から真偽値(bool)を返す関数を定義した時、それを満たす 最小の値 を範囲内から二分探索することを考えます。

前提として、定義する関数は指定範囲内(min~max)で以下のように規則正しく真偽値を返すものとします。

v min x-2 x-1 x x + 1 max
f(v) false false false true true true
探したい
最小値

f(max) == false の場合、この範囲内において求めるべき値は存在しません。

値を探す二分探索

同じように実装すると以下のような感じ。

pub fn search<F>(min: i32, max: i32, f: F) -> i32
where
    F: Fn(&i32) -> bool
{
    let mut left = min;
    let mut right = max;

    while left < right {
        let mid = (left + right) / 2;
        if f(&mid) {
            right = mid;   // f() を 満たす値 の最小値は mid 以下
        } else {
            left = mid + 1; // f() を 満たす値 は mid より大きい範囲にある (はず)
        }
    }

    // left == right
    right
}

最終的に left == right なのでどちらの値を返してもいいのですが、探している値が右側なので right を返しています。

探索が停止する直前は以下のような状態になります。

min left (= right - 1) right max
条件を満たさない 条件を満たす

この時 mid == left になり、f(&mid) の戻り値がどちらだったとしても left == right で探索が停止します。

  • min + 1 までがすべて条件を満たす場合でも、最後に min (left) に対するチェックが行われます。
    • f(min) が true: min が返る
    • f(min) が false: min + 1 が返る
  • ただし、right == max のまま探索が終了した場合、以下のどちらかです。
    • max だけが条件を満たす
    • 範囲内の値がすべて条件を満たさない

こちらも同様に、min/max に正しい値が渡されている限りは、問題になりません。

境界を探す二分探索

left と right が隣り合った時に while ループを抜けるように実装すると、以下のような感じ。

pub fn search<F>(min: i32, max: i32, f: F) -> i32
where
    F: Fn(&i32) -> bool
{
    let mut left = min;
    let mut right = max;

    while right - left > 1 {
        let mid = (left + right) / 2;
        if f(&mid) {
            right = mid; // f() を 満たす値 の最小値は mid 以下
        } else {
            left = mid;  // f() を 満たす値 は mid より大きい範囲にある (はず)
        }
    }

    // left == right - 1
    right - i32::from(f(&left))
}

探索が停止する直前、left と right は隣り合ってないので、以下の状態です。

min left mid right max
満たさない ? 満たす

f(&mid) の結果によらず、次の処理で left と right が隣り合う状態になり探索が停止します。

ただし left == min のまま while ループを抜けた場合、left が条件を満たすかどうかチェックする必要があり、満たす場合は right - 1 として left を返します。

競プロの小手先テクニック

変数名を少し工夫する

時間制限など緊張感がある中でコーディングしていると、小さいミスが命取りになります。特に文字列や二分探索を扱う問題の場合、lr などの変数を使った際に左右誤認で変数を取り違えてハマるなどの事故が起きます(バカ)(n敗)。

とりあえずの対策として、状況により以下のような変数を使うことを考えます。

  • start - end ( s - e )
  • begin - end ( b - e )
  • from - to ( f - t )
  • lb - ub
    • lower_bound - upper_bound
    • 蟻本で使われている。
  • ac - wa
    • AtCoder 公式解説放送をされてる snuke さんが使っている。
  • ok - ng
    • 探したい値(境界)により、left 側を ok にしたり ng にしたり。

※ もちろん問題によって left - right の方が理解しやすい場合もあります。

「境界を探す二分探索」のコードを簡略化する

この記事で言うところの「境界を探す二分探索」は以下の点で魅力的です。

  • mid の計算は脳死(left + right) / 2 でよい
  • left/right には脳死mid を入れとけばよい

ただし、厳密には最後で ± i32::from(small: bool) のようなコードで必要で、このコードは書いても書き忘れてもバグとなる可能性があります。また、コンテスト中にそれを書くのは労力と時間がかかります。

それさえなければ、特にRust 1.60.0 以降で、while文以下は完全な定型コードにできます*3。早速 left/right を ok/ng に変えて書いてみると以下のようになります。

    while ok.abs_diff(ng) > 1 {
        let mid = (ok + ng) / 2;
        if f(&mid) {
            ok = mid;
        } else {
            ng = mid;
        }
    }

    ok

2023/06/09 18:50 追記
この二分探索は「めぐる式二分探索法」と呼ばれているようです。
追記ここまで

できれば実戦でこの簡略化したコードを使いたいところです。そこで答えとなる値の最小値/最大値(範囲)は分かっている前提で、探索範囲を以下の様にして、± i32::from(small: bool) のコードを省略してしまいます。

  • 最大値 を探す場合、範囲の 最大値を max + 1 にする。
  • 最小値 を探す場合、範囲の 最小値を min - 1 にする。

それぞれ端の境界まで処理が進んだ時、以下のような状態になり、簡略化したコードでも正しい値が返ってきます。

最大値を探す場合

ok = minng = max + 1 から始まります。( ok < ng )

変数 ok mid ng
v max - 1 max max + 1
f(v) true ? ?

f(&mid) が true なら ok = mid になり、false なら ng = mid になって終了します。max より大きい値が返ることはありません。

最小値を探す場合

ng = min - 1ok = max から始まります。( ng < ok )

変数 ng mid ok
v min - 1 min min + 1
f(v) ? ? true

f(&mid) が true なら ok = mid になり、false なら ng = mid になって終了します。min より小さい値が返ることはありません。

  • 最小値/最大値を変えて範囲を広げているので、探索範囲の幅は2以上であることが保障されます。
    • 幅が2の場合、while ループは実行されずにokの値が返って終わります。
  • 別の覚え方として、初期値で ng 側に来る範囲を1だけ広げる、と言うのがありそう。

例題

コレクション内を二分探索

標準ライブラリやAtCoder環境で提供されているライブラリで二分探索すればよい問題です。

数値の範囲内を二分探索

コレクションに収まらない範囲を二分探索する問題です。

最大値/最小値を二分探索

答えとなる最大値/最小値を二分探索する問題です。俗に言う「答えで二分探索」。

まとめ

いかがでしたか?二分探索については結局よく分かりませんでした。

二分探索の問題の解説は「二分探索を使って解きます!回答例はこれ!」で終わることが多く、何も見ずに回答例のコードを書けるようになるには「ミッシングリンク」がある気がしたのでまとめてみました。本当は物理的なノートにまとめようと思ったのですが、一発で清書できる気が全くしなかったので、ブログの記事にしました。

二分探索コードのループ条件や境界値がどのような動きになるか分からなくなった時は、ループの終了条件や終了直前の状態*4を可視化して整理してみてください。

この個人メモが誰かしらの役に立てば幸いです。

Appendix

二分探索の計算回数表

left と right が隣り合うまで (幅が2以下になるまで) の最悪計算回数一覧

データ量  \displaystyle N 最悪計算回数  \displaystyle \log_2 N
 \displaystyle 10^{5} 16 16.61
 \displaystyle 10^{6} 19 19.93
 \displaystyle 10^{9} 29 29.90
 \displaystyle 10^{12} 39 39.86
 \displaystyle 10^{15} 49 49.83
 \displaystyle 10^{18} 59 59.79

base(left) と size による二分探索の実装

探索の起点(base, left)を更新しつつ、探索範囲(size)を半分にしていく実装方法もあります。mid = base + size / 2 となるので、オーバーフローを気にしないで済みます。

例えば、Rustで書かれたライブラリのsuperslice-rsは以下の様になっています。

superslice-rs/src/lib.rs at 0.1.0 · alkis/superslice-rs · GitHub

    fn lower_bound_by<'a, F>(&'a self, mut f: F) -> usize
    where
        F: FnMut(&'a Self::Item) -> Ordering,
    {
        let s = self;
        let mut size = s.len();
        if size == 0 {
            return 0;
        }
        let mut base = 0usize;
        while size > 1 {
            let half = size / 2;
            let mid = base + half;
            let cmp = f(unsafe { s.get_unchecked(mid) });
            base = if cmp == Less { mid } else { base };
            size -= half;
        }
        let cmp = f(unsafe { s.get_unchecked(base) });
        base + (cmp == Less) as usize
    }

Rustの binary_search では match syntax が使われていないというお話

match syntaxが比較処理の順序を変えてしまい、これが performance sensitive (perf sensitive) であるため、とコメントに書かれています。

https://github.com/rust-lang/rust/blob/1.70.0/library/core/src/slice/mod.rs#L2508-L2548

            // The reason why we use if/else control flow rather than match
            // is because match reorders comparison operations, which is perf sensitive.
            // This is x86 asm for u8: https://rust.godbolt.org/z/8Y8Pra.
            if cmp == Less {
                left = mid + 1;
            } else if cmp == Greater {
                right = mid;
            } else {
                // SAFETY: same as the `get_unchecked` above
                unsafe { crate::intrinsics::assume(mid < self.len()) };
                return Ok(mid);
            }

C++lower_bound()upper_bound()

これらは二分探索で指定要素を検索するので、実は指定要素より前にある値が指定要素未満(または指定要素以下)になっていれば意図した通りに動作します。

    // lower_bound で 4 以上の要素の位置を検索する場合、
    // 4 より小さい物と 4 以上の物がその順に並んでいれば、
    // 必ずしもソートされている必要はない。
    std::vector<int> v = {3, 1, 4, 6, 5};

    // upper_bound で 3 より大きい要素の位置を検索する場合、
    // 3 以下の物と 3 より大きい物がその順に並んでいれば、
    // 必ずしもソートされている必要はない。
    std::vector<int> v = {3, 1, 4, 6, 5};

参考文献など

*1:32bit環境であれば u32 の上限値、64bit環境であれば u64 の上限値

*2:使えているとは言っていない

*3:AtCoder は 2023/06 時点で 1.42.0 です…

*4:あるいは極端に狭い範囲での開始状態