マルチキークイックソートとstd::sortの比較

以前にCommon Lispで実装したマルチキークイックソートC++で書き直して、STLのstd::sortと速度を比較してみた。

使用データ

Wikipediaの記事タイトル約500万行を使用。
データ作成方法などはここを参照。

$ wc -l wiki.title.500
5340378 wiki.title.500

$ ls -lh wiki.title.500
-rw-r--r-- 1 user user 103M 2010-10-27 02:14 wiki.title.500

これを次のRubyスクリプトでシャッフルする。

# ファイル名: shuffle.rb
lines=open(ARGV[0]).read.split("\n")

2.times{|k|
  lines.size.times{|i|
    j = rand(lines.size)
    tmp = lines[i]
    lines[i] = lines[j]
    lines[j] = tmp
  }
}

lines.each{|l|
  puts l
}
$ ruby shuffle.rb wiki.title.500 > wiki.shuffle

$ head wiki.shuffle
西村知道
カービィのピンボール
Hudson_K〓rfezi
De_Sanctis-Cacchionen_syndrooma
倍数性
Arto_Kulmala

比較結果

実行速度の比較結果。
比較に用いたソースコードは末尾に掲載。

# マルチキークイックソート
$ ./mqsort 0 < wiki.shuffle > /dev/null
MULTIKEY QUICK SORT:
Elapsed: 4.58632 sec

# std::sort
$ ./mqsort 1 < wiki.shuffle > /dev/null
std::sort:
Elapsed: 9.90535 sec

この例ではマルチキークイックソートの方がニ倍程度早かった。

ソースコード

実装ソースコード
以前のCommon LispのそれをC++に直訳したもの。
チューニングすればもう少し早くなるかも*1

/**
 * ファイル名: mqsort.cc
 * コンパイル: g++ -O2 -o mqsort mqsort.cc   # gcc version 4.4.3 (Ubuntu 4.4.3-4ubuntu5)
 *
 * 標準入力から行を読み込み、ソートした結果を出力する。
 * ソート方法はマルチキークイックソートかstd::sort。
 * ※ コマンドの第一引数に0を渡せば前者、それ以外なら後者。
 *
 * Usage: mqsort [0 or 1] < [入力ファイル]  > [出力ファイル]
 */
#include <iostream>
#include <vector>
#include <string>
#include <cstring>
#include <algorithm>

/**
 * 計時用関数
 */
#include <sys/time.h>
inline double gettime(){
  timeval tv;
  gettimeofday(&tv,NULL);
  return static_cast<double>(tv.tv_sec)+static_cast<double>(tv.tv_usec)/1000000.0;
}


/**
 * マルチキークイックソート
 */
struct Range {
  Range(const char** b, const char** e) : begin(b), end(e) {}
  const char** begin;
  const char** end;
};

inline void swap_range(const char** beg1, const char** beg2, int count) {
  for(int i=0; i < count; i++)
    std::swap(beg1[i], beg2[i]);
}

inline unsigned char code(const char** ptr, int depth) {
  return (unsigned char)(*ptr)[depth];
}

inline void set_pivot_at_front(const char** beg, const char** end, int depth) {
  const char** mid = beg + (end-beg)/2;
  const char** las = end-1;
  unsigned char a = code(beg,depth);
  unsigned char b = code(mid,depth);
  unsigned char c = code(las,depth);

  if(a<b) {
    if(a<c) {
      b<c ? std::swap(*beg,*mid) : std::swap(*beg,*las);
    }
  } else {
    if(b<c) {
      std::swap(*beg,*mid);
    } else {
      if(a<c)
        std::swap(*beg,*las);
    }
  }
}

Range partition(const char** beg, const char** end, int depth) {
  set_pivot_at_front(beg, end, depth);
  
  unsigned char pivot = code(beg,depth);
  const char** ls_front = beg+1;
  const char** ls_last  = beg+1;
  const char** gt_front = end-1;
  const char** gt_last  = end-1;

  for(;;) {
    for(; ls_last <= gt_front && code(ls_last,depth) <= pivot; ls_last++)
      if(pivot == code(ls_last,depth)) {
        std::swap(*ls_front, *ls_last);
        ls_front++;
      }
    for(; ls_last <= gt_front && code(gt_front,depth) >= pivot; gt_front--)
      if(pivot == code(gt_front,depth)) {
        std::swap(*gt_front, *gt_last);
        gt_last--;
      }

    if(ls_last > gt_front)
      break;
    std::swap(*ls_last, *gt_front);
    ls_last++;
    gt_front--;
  }

  const char** ls_beg = ls_front;
  const char** ls_end = ls_last;
  const char** gt_beg = ls_last;
  const char** gt_end = gt_last+1;
  
  int len1 = std::min(ls_beg-beg, ls_end-ls_beg);
  swap_range(beg, ls_end-len1, len1);

  int len2 = std::min(end-gt_end, gt_end-gt_beg);
  swap_range(gt_beg, end-len2, len2);

  return Range(ls_end-(ls_beg-beg),
               gt_beg+(end-gt_end));
}

inline void swap_if_greater(const char*& x, const char*& y, int depth) {
  if(strcmp(x+depth, y+depth) > 0) 
    std::swap(x, y);
}

void mqsort_impl(const char** begin, const char** end, int depth) {
  int len = end-begin;
  if(len <= 2) {
    if(len == 2) 
      swap_if_greater(*begin, *(begin+1), depth);
  } else {
    Range r = partition(begin, end, depth);
    mqsort_impl(begin, r.begin, depth);
    if(code(r.begin,depth) != '\0')
      mqsort_impl(r.begin, r.end, depth+1);
    mqsort_impl(r.end, end, depth);
  }
}

void mqsort(const char** begin, const char** end) {
  mqsort_impl(begin, end, 0);
}

// std::sortに渡す比較関数
inline bool str_less_than(const char* a, const char* b) {
  return strcmp(a,b) < 0;
}

/**
 * main関数
 */
int main(int argc, char** argv) {
  // 標準入力から行読み込み
  std::vector<std::string> lines;
  std::string line;
  while(getline(std::cin,line)) 
    lines.push_back(line);

  // char*の配列に変換
  const char** words = new const char*[lines.size()];
  for(int i=0; i < lines.size(); i++)
    words[i] = lines[i].c_str();
  
  // ソート
  double beg_t = gettime();
  if(strcmp(argv[1],"0")==0) {
    std::cerr << "MULTIKEY QUICK SORT:" << std::endl;
    mqsort(words, words+lines.size());
  } else {
    std::cerr << "std::sort:" << std::endl;
    std::sort(words, words+lines.size(), str_less_than);
  }
  std::cerr << "Elapsed: " << gettime()-beg_t << " sec" << std::endl;

  // 結果出力
  for(int i=0; i < lines.size(); i++) 
    std::cout << words[i] << std::endl;

  return 0;
};

*1:クイックソートでは定番の、範囲が狭くなったら挿入ソートを使う、とか。