読者です 読者をやめる 読者になる 読者になる

DoubleArray + AC法

C++ algorithm speed

AC法(エイホ-コラシック法)というものを知ったのでDoubleArrayで試してみた。

概要

入力テキストから、辞書(= 文字列セット)に登録されている文字列を効率的に検索するための方法らしい*1
詳細は、Wikipedia『エイホ-コラシック法』の項目を参照。


自分の理解としては、基本的には通常のトライと同じだけど、各ノードが自身が表す文字列(ルートからのパス)サフィックス(接尾文字列)に対応するノードへのリンクを有する点が異なる、といった感じ。
以下は、既存のトライに対して、上述のリンクを付加するための擬似コード

# Ruby形式の擬似コード
trie.traverse_all_node do |path, node|   # 全ノードを探索: path = ルートからnodeまでの文字を繋げた文字列
  suffix_link     = trie.root_node   # 接尾文字列に対応するノード (初期値はルートノード)
 
  # 接尾文字列の内で、対応するノードがある最長のものを探す
  while path.size > 1 && suffix_link == trie.root_node
    path = path[1..-1]                   # 先頭の文字を除去した接尾文字列を作成する
    suffix_link = trie.search_node(path) # 接尾文字列に対応するノードを探す (見つからなければ、trie.root_nodeが返る)
  end

  node.suffix_link     = suffix_link
end

サフィックスリンクがない通常のトライの場合、入力テキストにマッチする全ての文字列を検索しようとすると、次のようにテキスト内の各位置でcommon-prefix-searchを行う必要がある。
このため最悪の場合は、入力テキストの長さ x トライ内の最大の文字列の長さ、の処理ステップ(?)が必要となる(おそらく)

# 入力文字列が"abcdefg"の場合
1] "abcdefg"で検索   ※ 一回の検索で、何文字辿る必要があるかは、トライの内容に左右される
2] "bcdefg"で検索
3] "cdefg"で検索
4] "defg"で検索
5] "efg"で検索
6] "fg"で検索
7] "g"で検索
終了

対して、サフィックスリンク付きのトライでは、検索処理は次のようになる。※ かなり単純化している

# 入力文字列が"abcdefg"の場合
1] ルートノードから"a"に対応するノードに遷移 => サフィックスリンク※1を辿り、マッチングがあるかを調べる
2]  上のノードから"b"に対応するノードに遷移※2 => サフィックスリンクを辿り、マッチングがあるかを調べる
3]  上のノードから"c"に対応するノードに遷移 => サフィックスリンクを辿り、マッチングがあるかを調べる
4]  上のノードから"d"に対応するノードに遷移 => サフィックスリンクを辿り、マッチングがあるかを調べる
5]  上のノードから"e"に対応するノードに遷移 => サフィックスリンクを辿り、マッチングがあるかを調べる
6]  上のノードから"f"に対応するノードに遷移 => サフィックスリンクを辿り、マッチングがあるかを調べる
7]  上のノードから"g"に対応するノードに遷移 => サフィックスリンクを辿り、マッチングがあるかを調べる
終了

# ※1 Wikipediaの記事では、ここで辿るサフィックスリンクを「辞書サフィックスリンク」と呼んでいる
#
#     ただのサフィックスリンクは、末尾文字列に対応する(最長の)トライ内の任意のノードへのリンクであるのに対して、
#     辞書サフィックスリンクにはさらに、辞書に登録されている文字列に対応するノード、へのリンクでなければならないという制限がつく。
#
#     この制限のために、上でリンクを辿る回数は、マッチングがあるノードの数と等しくなり、必要最小限に抑えられる。
#
#     上の擬似コードには「辞書サフィックスリンク」用の記述が抜けている

# ※2 実際には対応する遷移先にノードがないため遷移できないことがある
#     その場合は、遷移元ノードのサフィックスリンクを辿り、そこから再度遷移が可能かどうかを試す必要がある 
#
#     下の処理ステップに関する記述では、このための処理に掛かるコストは無視している

この場合の処理ステップ(?)は、入力テキストの長さ + マッチングの数、となりリンクなしの場合に比べて低く抑えることが可能となる(おそらくこれで大きくは間違っていないはず...)
これがAC法の利点(多分)

実装1: 比較用トライ検索

ここから実装に入る。
まずは比較用に通常のトライ(DoubleArray)を用いた辞書検索コマンドを作成する。
マッチする文字列を取り出す(or 取り出し表示する)処理は、そのためのオーバヘッドが大きくなる可能性があるので、今回はマッチ数をカウントするだけにする。
 ※ ここで使用しているDoubleArrayの実装やmkdaコマンドに関しては『基本となるDoubleArrayの実装』を参照

/***
 * ファイル名: count-word.cc
 * コンパイル: g++ -O2 -o count-word count-word.cc
 *
 * 入力テキスト内で辞書(= mkdaコマンドが作成したDoubleArrayインデックス)にマッチする文字列の数をカウントする
 */
#include <iostream>
#include <sys/time.h>
#include "double_array.h"
#include "char_stream_vector.h"

// 計時用関数
inline double gettime(){
  timeval tv;
  gettimeofday(&tv,NULL);
  return static_cast<double>(tv.tv_sec)+static_cast<double>(tv.tv_usec)/1000000.0;
}

// マッチ数カウント用のクラス
struct Counter {
  Counter() : count(0) {}
  void operator()(const char* key, unsigned length, int id) {
    //std::cout << id << "#" << std::string(key, length) << std::endl;
    count++;
  }
  unsigned count;
};

int main(int argc, char** argv) {
  if(argc != 3) {
    std::cerr << "Usage: count-word <index> <text>" << std::endl;
    return 1;
  }
  
  DoubleArray da(argv[1]);
  CharStreamVector csv(argv[2]);
  Counter fn;

  // 入力テキストの各行を読み込み、辞書に登録されている文字列の数をカウントする
  double beg_t = gettime();
  for(unsigned i=0; i < csv.size(); i++) {
    const char* line = csv[i].rest();
    for(; *line != '\0'; line++)
      da.each_common_prefix(line, fn);
  }

  std::cerr << "Matched word count: " << fn.count << std::endl;
  std::cerr << "Elapsed: " << gettime()-beg_t << " sec" << std::endl << std::endl;

  return 0;
}

実行。
辞書にはWikipediaの各国語のタイトル(取得方法はここを参照)MeCabのサイトより配布されているIPADICから単語名を抽出したものを使用。

# データ準備
$ mkda wiki.idx wiki.title
$ nkf -w mecab-ipadic-2.7.0-20070801/*.csv | cut -d',' -f1 | LC_ALL=C sort | LC_ALL=C uniq > ipa.word
$ mkda ipa.idx ipa.word
$ ls *.idx
-rw-r--r-- 1 user user  11M 2010-07-03 18:29 ipa.idx
-rw-r--r-- 1 user user 391M 2010-07-03 18:26 wiki.idx

# カウント: 入力テキストには、夏目漱石の『こころ』(青空文庫より入手. UTF-8に変換)を使用
$ ./count-word wiki.idx kokoro
Word count: 202732
Elapsed: 0.013299 sec

$ ./count-word ipa.idx kokoro
Word count: 269405
Elapsed: 0.0121629 sec

実装2: AC法

AC法。
トライの部分は通常のDoubleArrayと同様で、各ノードに追加情報を付与する、という実装になっている。

まずは、追加情報を付与するコマンドから作成する。

/***
 * ファイル名: double_array.h
 *
 * DoubleArray検索部分。
 * AC法を実装するために必要な箇所を修正。
 * 修正箇所のみ記述。
 */
#include <string>  // 追加:

class DoubleArray {
private:
  unsigned node_size_;  // 追加: ノード数

public:
  DoubleArray(const char* filepath) {
    FILE* f = fopen(filepath, "rb");
    
    unsigned node_size;
    fread(&node_size, sizeof(unsigned), 1, f);
    base = new int [node_size];
    chck = new int [node_size];
    
    fread(base, sizeof(int), node_size, f);
    fread(chck, sizeof(int), node_size, f);
    fclose(f);
    
    node_size_ = node_size;  // 追加:
  }

  // 追加:
  unsigned node_size() const { return node_size_; } 

  // 追加: トライの全てのノードを探索する
  template<class Callback>
  void traverse_all_node(Callback& fn) const {
    std::string path;
    traverse_all_node(fn, path, 0);
  }

  // 追加: 入力文字列にマッチするノードを返す
  int search_node(const char* key) const {
    int node=0;
    CharStream in(key);

    for(int len=0;; len++) {
      if(in.peek()=='\0')
        return node;      // マッチノード

      int next_node = base[node] + in.read();
      if(chck[next_node] == node) 
        node = next_node;
      else
        return 0;         // ルートノード
    }
  }

  // 追加: 入力文字列にマッチするノードを返す
  //       入力文字列にマッチするトライ内のキーが存在しない場合は、ルートノードが返される
  int search_key_node(const char* key) const {
    int node=0;
    CharStream in(key);

    for(int len=0;; len++) {
      if(in.peek()=='\0')
        return chck[base[node]+'\0']==node ? node : 0;  // キーが存在するかどうか
      
      int next_node = base[node] + in.read();
      if(chck[next_node] == node) 
        node = next_node;
      else
        return 0;  // ルートノード
    }
  }

  // 追加: ノードの深さを求める
  unsigned depth(int node) const {
    unsigned dep=0;
    for(; node != 0; node = chck[node]) dep++;
    return dep;
  }

  // 追加:
  const int* base_ptr() const { return base; }
  const int* chck_ptr() const { return chck; }

private:
  // 追加: トライの全てのノードを探索する
  template<class Callback>
  void traverse_all_node(Callback& fn, std::string& path, int parent_node) const {
    for(unsigned child=1; child < 0x100; child++) { // 終端ノード(='\0')は飛ばす
      int child_node = base[parent_node] + child;
      if(parent_node == chck[child_node]) {
        path += static_cast<char>(child);
        fn(path.c_str(), child_node);
        traverse_all_node(fn, path, child_node);
        path.resize(path.size()-1);
      }
    }
  }

protected:  // 修正: 継承したクラスから参照可能なようにしておく
  int *base;  
  int *chck;  
/***
 * ファイル名: mkac.cc
 * コンパイル: g++ -O2 -o mkac mkac.cc
 *
 * mkdaが作成したインデックスに、AC法のための情報を付与する
 */
#include <iostream>
#include <cstdio>
#include "double_array.h"

class AC_Builder {
public:
  AC_Builder(const DoubleArray& da) 
    : da(da) {
    suffix_link =     new int[da.node_size()];
    key_suffix_link = new int[da.node_size()];
    node_depth      = new short[da.node_size()];
    
    // 初期化
    for(unsigned i=0; i < da.node_size(); i++) {
      suffix_link[i]     = 0;  // ルートノード
      key_suffix_link[i] = 0;  // ルートノード
      node_depth[i]      = 0;  // ルートノードの深さ
    }
  }

  ~AC_Builder() {
    delete [] suffix_link;
    delete [] key_suffix_link;
    delete [] node_depth;
  }

  // traverse_all_nodeメソッドのコールバック
  void operator()(const char* path, int node) {
    //if(node%10000 == 0)
    //  std::cout << node << "#" << path << std::endl;
    
    int suffix_node=0;
    do {
      if(*(++path)=='\0')
        break;
    } while(0 == (suffix_node=da.search_node(path)));
    
    int key_suffix_node = da.search_key_node(path);
    while(0 == key_suffix_node) {
      if(*(++path)=='\0')
        break;
      key_suffix_node=da.search_key_node(path);
    }

    suffix_link[node]     = suffix_node;
    key_suffix_link[node] = key_suffix_node;
    node_depth[node]      = da.depth(node);
  }

  void save(const char* filepath) const {
    FILE *f = fopen(filepath, "wb");
    unsigned size = da.node_size();
    fwrite(&size, sizeof(unsigned), 1, f);
    fwrite(da.base_ptr(), sizeof(int), size, f);
    fwrite(da.chck_ptr(), sizeof(int), size, f);
    fwrite(suffix_link, sizeof(int), size, f);
    fwrite(key_suffix_link, sizeof(int), size, f);
    fwrite(node_depth, sizeof(short), size, f);
    fclose(f);    
  }

private:
  const DoubleArray& da;
  int*   suffix_link;     // サフィックスのリンク
  int*   key_suffix_link; // 辞書サフィックスのリンク
  short* node_depth;      // ノードの深さ: 予め計算/保持しておくことで、検索時のコストを抑える
};

int main(int argc, char** argv) {
  if(argc != 3) {
    std::cerr << "Usage: mkac <ac-index> <da-index>" << std::endl;
    return 1;
  }
  
  DoubleArray da(argv[2]);
  AC_Builder  acb(da);

  da.traverse_all_node(acb);

  acb.save(argv[1]);

  return 0;
}

実行。

$ ./mkac ipa-ac.idx ipa.idx
$ ./mkac wiki-ac.idx wiki.idx
$ ls -lh *.idx
-rw-r--r-- 1 user user  24M 2010-07-03 19:28 ipa-ac.idx
-rw-r--r-- 1 user user  11M 2010-07-03 18:29 ipa.idx
-rw-r--r-- 1 user user 879M 2010-07-03 19:29 wiki-ac.idx
-rw-r--r-- 1 user user 391M 2010-07-03 18:26 wiki.idx


次は、AC法を用いたワードカウントコマンド(と、そのための検索クラス)

/***
 * ファイル名: double_array_ac.h
 *
 * DoubleArray検索のAC法版
 */
#ifndef DOUBLE_ARRAY_AC_H
#define DOUBLE_ARRAY_AC_H

#include <cstdio>
#include "char_stream.h"
#include "double_array.h"
#include <string>

#include <cassert>

// DoubleArray(AC法版)検索用のクラス
class DoubleArrayAC : public DoubleArray {
public:
  DoubleArrayAC(const char* filepath) 
    : DoubleArray(filepath) {
    suffix_link     = new int[node_size()];
    key_suffix_link = new int[node_size()];
    node_depth      = new short[node_size()];

    FILE *f = fopen(filepath, "rb");
    fseek(f, node_size()*sizeof(int)*2+sizeof(unsigned), SEEK_SET); // DoubleArrayクラスのデータ格納部分は飛ばす
    fread(suffix_link, sizeof(int), node_size(), f);
    fread(key_suffix_link, sizeof(int), node_size(), f);
    fread(node_depth, sizeof(short), node_size(), f);
    fclose(f);
  }    

  ~DoubleArrayAC() {
    delete [] suffix_link;
    delete [] key_suffix_link;
    delete [] node_depth;
  }

  // 入力文字列にマッチする全てのキーを検索する
  template<class Callback>
  void each_match(const char* text, Callback& fn) const {
    int node=0;
    CharStream in(text);
    for(;; in.read()) {
      // このノードに対応するキーがあるか
      int terminal = base[node]+'\0';
      if(chck[terminal] == node) 
        fn(in.rest()-node_depth[node], node_depth[node], -base[terminal]);

      // このノードの辞書サフィックスリンクを辿ってキーを探す
      for(int tmp=key_suffix_link[node]; tmp!=0; tmp=key_suffix_link[tmp])
        fn(in.rest()-node_depth[tmp], node_depth[tmp], -base[base[tmp]+'\0']);

      // 次のノードに遷移する
      for(;;) {
        if(in.peek()=='\0') 
          return;       // 検索終了

        int next = base[node] + in.peek();
        if(chck[next] == node) {
          node = next;  // 遷移
          break;
        } 

        // 現在のノードからの有効な遷移がない場合
        if(node==0)
          in.read();                 // ルートノードなら遷移文字を読み進める
        else
          node = suffix_link[node];  // サフィックスリンクを辿って、遷移をやり直す
      } 
    }    
  }

private:
  int*   suffix_link;     // サフィックスのリンク
  int*   key_suffix_link; // 辞書サフィックスのリンク
  short* node_depth;      // ノードの深さ: 予め計算/保持しておくことで、検索時のコストを抑える
};

#endif
/***
 * ファイル名: ac-count-word.cc
 * コンパイル: g++ -O2 -o ac-count-word ac-count-word.cc
 *
 * 入力テキスト内で辞書(= mkacコマンドが作成したDoubleArray(AC法版)インデックス)にマッチする文字列の数をカウントする
 */
#include <iostream>
#include <sys/time.h>
#include "double_array_ac.h"
#include "char_stream_vector.h"

inline double gettime(){
  timeval tv;
  gettimeofday(&tv,NULL);
  return static_cast<double>(tv.tv_sec)+static_cast<double>(tv.tv_usec)/1000000.0;
}

struct Counter {
  Counter() : count(0) {}
  void operator()(const char* key, unsigned length, int id) {
    //std::cout << id << "#" << std::string(key, length) << std::endl;
    count++;
  }
  unsigned count;
};

int main(int argc, char** argv) {
  if(argc != 3) {
    std::cerr << "Usage: ac-count-word <index> <text>" << std::endl;
    return 1;
  }
  
  DoubleArrayAC da(argv[1]);
  CharStreamVector csv(argv[2]);
  Counter fn;

  double beg_t = gettime();
  for(unsigned i=0; i < csv.size(); i++) {
    const char* line = csv[i].rest();
    da.each_match(line, fn);
  }

  std::cerr << "Matched word count: " << fn.count << std::endl;
  std::cerr << "Elapsed: " << gettime()-beg_t << " sec" << std::endl << std::endl;

  return 0;
}

実行。

$ ./ac-count-word wiki-ac.idx ~/kokoro
Matched word count: 202732
Elapsed: 0.01694 sec

$ ./ac-count-word ipa-ac.idx ~/kokoro
Matched word count: 269405
Elapsed: 0.0136041 sec

AC法版の方が遅い結果となった。

感想とか

検索時のループ数*2は、AC法版の方が少なく、通常版の(今回のデータでは)1/2程度であった。
ループ数は少なくなったのに処理時間が長くなっているのは、多分検索時に使用する配列が増えてメモリキャッシュの効率が(結構大幅に)下がったためではないかと思う。
加えて、AC法版の方が検索メソッドが行う処理が複雑になっているということも(どの程度かは分からないが)関係していると思う。
いずれにせよ、AC法を用いたトライは-説明を読んだ時には速そうに感じたけど-実際に実装してみると(少なくとも常には)通常のトライに対して速度的に優れているわけではない、ということは分かった*3
あと、データサイズが大きくなる、という欠点もある*4

*1:入力テキスト/辞書の文字列の長さ/検索マッチ数 に対して線形

*2:DoubleArray.common_prefix_searchの場合は、内部のforループの回数。DoubleArrayAC.each_matchの場合は、内部の三つのforループの総数

*3:僕の実装が悪い、という可能性は残っているけど

*4:AC法のための追加する情報(= 配列)を減らすことで、データサイズをある程度切り詰めることは可能。ただし、そうした場合、AC法を利用するそもそもの動機であった処理速度が遅くなる可能性がある。また、AC法はその性質上、トライの全ての要素をノードとして表現する必要があるので、DoubleArrayのTAIL配列のような(ノード数自体を少なくすることによる)データ圧縮方法が使えず、サイズ縮小には限界がある。