crubyのFloatにFast inverse square rootを(モンキーパッチで)追加する話

Fast inverse square root自体については「30のプログラミング言語でFast inverse square rootを実装してみました!」( http://itchyny.hatenablog.com/entry/2016/07/25/100000 )を参照してください。

せっかくの高速な手法ですから、高速な実装を考えてみます。Rubyの場合、まっとうに実装するのであれば Math モジュールに追加するべきですが、性能のことを考えると関与するオブジェクトが少ないほうが良いですから、Float クラスにモンキーパッチで追加することにしました。

こんな感じになります。

#include <stdint.h>
#include <float.h>
#include "ruby.h"

static VALUE
fast_inv_sqrt(VALUE v)
{
  double d = RFLOAT_VALUE(v);
  if ((d > 0.0) && (d <= DBL_MAX)) {
    //NOP
  } else {
    rb_raise(rb_eFloatDomainError, "out of domain");
  }
  union u_di {
    double d;
    uint64_t i;
  } d_i;
  d_i.d = d;
  d_i.i = 0x5fe6eb50c7b537a8ULL - (d_i.i >> 1);
  double d2 = d / 2.0;
  double g = d_i.d;
  d_i.d = g * (1.5 - d2*g*g);
#if 0
  d_i.i += 0x4e1c39a47a8ULL;
  g = d_i.d;
  d_i.d = g * (1.5 - d2*g*g);
  d_i.i += 0x14b9e8a78ULL;
#endif
  g = d_i.d;
  return DBL2NUM(g);
}

void
Init_fisqrt(void)
{
  rb_define_method(rb_cFloat, "fast_inv_sqrt", fast_inv_sqrt, 0);
}

ポイントとなる点をいくつか説明してゆきます。

まず Init_fisqrt が拡張ライブラリのエントリポイントで、これは見たままかと思います。

関数 fast_inv_sqrt が本体で、最初と最後にある VALUE と double の相互変換は拡張ライブラリの書き方の定番通りです。

続いて引数の範囲チェックですが、浮動小数点数のチェックは「真になる条件」でチェックするのがセオリーです(比較に NaN が関わるると偽になるため)。

ビットパターンを保存した浮動小数点数と整数の変換に union を使っているのも、この手のトリックの定番と言っていいでしょう(標準において、未定義ではなく処理系定義のため、気休め程度だがマシ)。

単精度の時の 0x5f3759df に相当するマジックナンバーは、妥当と思われる方法で私が探索(詳細は略)したものです。

#if 0 でコメントアウトしてある部分は、せっかくなので高精度化を図ってみたコードですが、これがあると **(-0.5) で正確な値を計算するよりも遅くなってしまったのでコメントアウトしました。単にニュートン法をもう1回繰返すだけではなく、誤差で必ず負の側に予測可能な範囲でズレることがわかっているので、それを調整する加算も入れてあります。

続いてベンチマークに使ったスクリプトを示します。

#! /usr/local/bin/ruby23

require 'benchmark'

require_relative 'fisqrt'

a = 1.0

result = Benchmark.realtime {
  (0..10000000).each {|i|
    a.fast_inv_sqrt
    a.fast_inv_sqrt
    a.fast_inv_sqrt
    a.fast_inv_sqrt
    a.fast_inv_sqrt
    a.fast_inv_sqrt
    a.fast_inv_sqrt
    a.fast_inv_sqrt
    a.fast_inv_sqrt
    a.fast_inv_sqrt
  }
}
print "result: #{result}s\n"

result = Benchmark.realtime {
  (0..10000000).each {|i|
    a**(-0.5)
    a**(-0.5)
    a**(-0.5)
    a**(-0.5)
    a**(-0.5)
    a**(-0.5)
    a**(-0.5)
    a**(-0.5)
    a**(-0.5)
    a**(-0.5)
  }
}
print "result: #{result}s\n"

result = Benchmark.realtime {
  (0..10000000).each {|i|
    1.0 / Math.sqrt(a)
    1.0 / Math.sqrt(a)
    1.0 / Math.sqrt(a)
    1.0 / Math.sqrt(a)
    1.0 / Math.sqrt(a)
    1.0 / Math.sqrt(a)
    1.0 / Math.sqrt(a)
    1.0 / Math.sqrt(a)
    1.0 / Math.sqrt(a)
    1.0 / Math.sqrt(a)
  }
}
print "result: #{result}s\n"

次のようになります。わずかですが速いという結果が出ています。

result: 6.402902246918529s
result: 7.31780265388079s
result: 9.100973834982142s