Julia では argmin とか argmax などの関数が用意されていたのですが、 Rust で同様の処理をする方法がわからず毎回調べて時間を消耗していたので、まとめておこうと思います。

もっと簡単な書き方があれば教えてください。


fld, cld

fld/を単に使えば良い気もします。

fn main(){
    let a: i64 = 8;
    let b = 3;
    assert_eq!(num::Integer::div_floor(&a, &b), 2);
}
fn main(){
    let a: i64 = 8;
    let b = 3;
    assert_eq!(num::Integer::div_ceil(&a, &b), 3);
}

broadcast

map を使えば良い。

fn main() {
    let v = [1, -2, 3, -4, 5];
    let x = v.map(|x: i64| x.abs());

    assert_eq!(x, [1, 2, 3, 4, 5]);
}
fn main() {
    let v = vec![1, -2, 3, -4, 5];
    let x = v.iter().map(|&x:&i64| x.abs()).collect::<Vec<_>>();

    assert_eq!(x, vec![1, 2, 3, 4, 5]);
}

diff

fn main(){
    let a = vec![3, -1, 0];
    let d: Vec<i64> = a.windows(2).map(|w| w[1] - w[0]).collect();
    assert_eq!(d, [-4, 1]);
}

findmin(itr)

fn main() {
    let v = [1, 7, 7, 6];
    let x = v.iter().enumerate().min_by_key(|&(_, x)| x);
    
    assert_eq!(x, Some((0, &1)));
}

(index, x) の順になる。min_by_key は最小値を達成するインデックスが複数ある場合、最初のインデックスを返すところは findmin と同じですね。


findmax(itr)

max_by_key を使って書くと

fn main() {
    let v = [1, 7, 7, 6];
    let x = v.iter().enumerate().max_by_key(|&(_, x)| x);
    
    assert_eq!(x, Some((2, &7)));
}

(index, x) の順になる。max_by_keyは最大値が達成するインデックスが複数ある時、最後のインデックスが帰ってくるので、 最初のインデックスを返す findmax とは挙動が異なる。

use core::cmp::Reverse;

fn main() {
    let v = [1, 7, 7, 6];
    let x = v.iter().enumerate().min_by_key(|&(_, x)| Reverse(x));
    
    assert_eq!(x, Some((1, &7)));
}

とすれば最初のインデックスを返せる。


argmin(itr)

findmin とほぼ同じ。

fn main() {
    let v = [1, 7, 7, 6];
    let x = v.iter().enumerate().min_by_key(|&(_, x)| x).map(|x| x.0);
    
    assert_eq!(x, Some(0));
}

argmax(itr)

findmax とほぼ同じ。最大値を取る添字が複数ある場合に最後のインデックスを返したければ、

fn main() {
    let v = [1, 7, 7, 6];
    let x = v.iter().enumerate().max_by_key(|&(_, x)| x).map(|x| x.0);
    
    assert_eq!(x, Some(2))
}

最初としたければ、

use core::cmp::Reverse;

fn main() {
    let v = [1, 7, 7, 6];
    let x = v
        .iter()
        .enumerate()
        .min_by_key(|&(_, x)| Reverse(x))
        .map(|x| x.0);

    assert_eq!(x, Some(1));
}

cumsum(itr)

itertools_num を使ってよければ、

use itertools_num::ItertoolsNum;

fn main() {
    let v = [1, 1, 1];
    let x: Vec<i64> = v.iter().cumsum().collect();

    assert_eq!(x, vec![1, 2, 3]);
}

最初に0を入れ込みたい場合は、

fn main() {
    let v = [1, 1, 1];

    let mut x = vec![0];
    for i in 0..v.len() {
        x.push(x[i] + v[i]);
    }

    assert_eq!(x, vec![0, 1, 2, 3]);
}

cumprod(itr)

fn main() {
    let v = [1, 2, 3];

    let mut x = vec![1];
    for i in 0..v.len() {
        x.push(x[i] * v[i]);
    }

    assert_eq!(x, vec![1, 1, 2, 6]);
}

sortperm

fn main() {
    let v = [-5, 4, 1, -3, 2];
    let mut x = (0..v.len()).collect::<Vec<_>>();
    x.sort_by_key(|&i| &v[i]);

    assert_eq!(x, [0, 3, 2, 4, 1]);
}

sortがunstableでよければ、

fn main() {
    let v = [-5, 4, 1, -3, 2];
    let mut x = (0..v.len()).collect::<Vec<_>>();
    x.sort_unstable_by_key(|&i| &v[i]);

    assert_eq!(x, [0, 3, 2, 4, 1]);
}

searchsortedfirst

1.52.0 以降だと partition_point が使える。

fn main() {
    let v = [1, 2, 4, 5, 5, 7];

    let i = v.partition_point(|&x| !(x >= 4));
    assert_eq!(i, 2);

    let i = v.partition_point(|&x| !(x >= 5));
    assert_eq!(i, 3);

    let i = v.partition_point(|&x| !(x >= 3));
    assert_eq!(i, 2);

    let i = v.partition_point(|&x| !(x >= 9));
    assert_eq!(i, 6);

    let i = v.partition_point(|&x| !(x >= 0));
    assert_eq!(i, 0);
}

AtCoder で使える Rust 1.42.0 だと工夫して

use std::cmp::Ordering::{Greater, Less};

fn main() {
    let v = [1, 2, 4, 5, 5, 7];

    let i = v
        .binary_search_by(|&x| if !(x >= 4) { Less } else { Greater })
        .unwrap_or_else(|i| i);
    assert_eq!(i, 2);

    let i = v
        .binary_search_by(|&x| if !(x >= 5) { Less } else { Greater })
        .unwrap_or_else(|i| i);
    assert_eq!(i, 3);

    let i = v
        .binary_search_by(|&x| if !(x >= 3) { Less } else { Greater })
        .unwrap_or_else(|i| i);
    assert_eq!(i, 2);

    let i = v
        .binary_search_by(|&x| if !(x >= 9) { Less } else { Greater })
        .unwrap_or_else(|i| i);
    assert_eq!(i, 6);

    let i = v
        .binary_search_by(|&x| if !(x >= 0) { Less } else { Greater })
        .unwrap_or_else(|i| i);
    assert_eq!(i, 0);
}

searchsortedlast

1.52.0 以降だと partition_point が使える。 インデックスに注意。

fn main() {
    let v = [1, 2, 4, 5, 5, 7];

    let i = v.partition_point(|&x| x <= 4);
    assert_eq!(i, 3);

    let i = v.partition_point(|&x| x <= 5);
    assert_eq!(i, 5);

    let i = v.partition_point(|&x| x <= 3);
    assert_eq!(i, 2);

    let i = v.partition_point(|&x| x <= 9);
    assert_eq!(i, 6);

    let i = v.partition_point(|&x| x <= 0);
    assert_eq!(i, 0);
}
use std::cmp::Ordering::{Greater, Less};

fn main() {
    let v = [1, 2, 4, 5, 5, 7];

    let i = v
        .binary_search_by(|&x| if x <= 4 { Less } else { Greater })
        .unwrap_or_else(|i| i);
    assert_eq!(i, 3);

    let i = v
        .binary_search_by(|&x| if x <= 5 { Less } else { Greater })
        .unwrap_or_else(|i| i);
    assert_eq!(i, 5);

    let i = v
        .binary_search_by(|&x| if x <= 3 { Less } else { Greater })
        .unwrap_or_else(|i| i);
    assert_eq!(i, 2);

    let i = v
        .binary_search_by(|&x| if x <= 9 { Less } else { Greater })
        .unwrap_or_else(|i| i);
    assert_eq!(i, 6);

    let i = v
        .binary_search_by(|&x| if x <= 0 { Less } else { Greater })
        .unwrap_or_else(|i| i);
    assert_eq!(i, 0);
}

usize で 絶対値をを取るとき (abs_diffが入っていないバージョンの時)

fn abs_diff<T>(a: T, b: T) -> T
where
    T: Sub<Output = T> + PartialOrd,
{
    if a > b {
        a - b
    } else {
        b - a
    }
}