LOUDS++(6): trie改良試作(TAIL配列版)

前回に作成したLOUDS(LOUDS++)を用いたtrieの改良を試みる。

改良案

前回のtrie実装は、まだまだ全然最適化されていないので、改良すべき(or できるであろう)箇所は結構沢山ある。
案として、例えば...

  • bit-vectorの実装方法 ...
    • 他の実装方法の方が効率的(使用サイズが少ない or 高速)では?
    • 現在の実装方法にしても、ここの関数/メソッドの実装は改善できるのでは?
  • trie ...
    • LoudsTrie.idフィールドにはselect_indices配列は不要? (キーのIDから、キー文字列を取得したい場合は必要だけど)
    • 今は終端を専用のビット列(LoudsTrie.id)で管理しているが、代わりにヌル文字ノードを使うようにした方が、良いのでは? (少なくとも実装は簡略化される)
  • キー検索方法 ...
    • 兄弟ノードが多い場合に、先頭から順にノード文字を比較していくの非効率 (二分探索、その他を場合に応じて使い分けてはどうか)
    • 上位nノードまでのselect1/rank1の結果はメモ化しておいても良いのでは?

などなど。

実用の場面で使えるものを作ろうとするなら、こういった(云わば実装の細部に関する)改良も必要だと思う。
ただ、今回は、このような実装の細部に関わる修正は置いておいて、trieの構造自体の変更を試してみようと思う。(とはいってもそう大したことをするわけでもないが)

TAIL配列

では、実際にはどうするかというと、TAIL配列*1を使う。

大雑把に云えば、trieの各キーで共有がある先頭部分の文字列は普通にノードとして保持し、残りの各キーにユニークな末尾部分の最初の一文字以外は(そのためのノードを設けるのではなく)TAILという配列にまとめて入れてしまおう、といったような方法。


例:
キーセット。

関西外国語短大
関西大学
関西学院大学西鉄鋼新聞
関西鉄鋼短大

キーセットに対応する木構造

TAIL配列を用いた場合。※ 一番下にあるのがTAIL配列

さらにTAIL配列の中で包含する文字列をまとめた場合。※ 上の図に比べてTAIL配列が若干短くなっている

図を見てもらえば分かると通り、TAIL配列使ってキーのユニークな末尾部分を管理するようにすると、trieのノード数を減らすことが可能となる(上の例では20個から9個に減っている)
その分はTAIL配列に移動しているので、実質的な総数は変わらないが、(前回のLOUDSによるtrie実装では)ノード一つ当たりには最大14bit*2の領域が必要なのに対し、TAIL配列では(要素は単なるchar型なので)一文字当たり8bit、と大きければ半分近く節約出きることになる。
また、最後の図にあるようにTAIL配列中で包含関係にある末尾文字列同士をまとめることで、さらに使用領域を削ることができる。
加えて、TAIL配列を使用すれば、前回のtrieでキー検索時にユニークな末尾部分を辿る際にも呼び出していた(ある意味無駄な)select1/rank1を省くことも可能となる(strcmp関数の一回の呼び出しに置換可能)
select1/rank1呼び出しのコストが低くはないことを考えると、これは結構重要なことのように思える。


TAIL配列の概要と、そのメリット(と思われること)は大体以上。

デメリット

一応思いつくデメリットも書いておく。

  • TAIL配列内の各末尾文字列を区切るためにヌル文字を使用する必要がある
    • 各キーごとに8bit余計に必要になる。 ※ ただし、末尾文字列を包含する他の文字列がある場合は不要
  • TAIL配列内の文字列の開始位置を示すインデックスが必要
    • インデックスのサイズを4byteとすれば、キー数*4byteの領域が必ず必要となる
  • ユニークな末尾部分が僅かにしかないようなキーセットの場合は、サイズ的/処理的なオーバヘッドの方が大きくなる(多分)
  • trie構築が若干面倒かつ非効率になる

上のような理由から、キー数が少ない場合やユニークな末尾文字列が僅かにしかないキーセットの場合は、TAIL配列を使用するオーバーヘッドの大きくなる可能性もあると思う。
ただし、LOUDSが重宝するような大量のキーを扱うケースでは、これらのデメリットよりもTAIL配列を使うことで節約可能な領域およびselect1/rank1呼び出しによるメリットの方が大きいように思うが、その辺りは全然調査していないので現時点では'思う'としか云えない。

実装

残りは、TAIL配列を用いたtrieの実装。
大半の部分は前回と同じ。
変更がないソースファイルは割愛することにする。


まず、構築部分。

/***
 * ファイル名: mklouds-trie-tail.cc
 * コンパイル: g++ -o mklouds-trie-tail mklouds-trie-tail.cc
 * 使用方法: mklouds-trie-tail <キーセットを保持するソート済みファイル> <バイナリデータの出力先>
 */
#include "char_stream_vector.h"
#include "bit_vector_builder.h"
#include <cstdio>
#include <iostream>
#include <deque>
#include <string>
#include <vector>
#include "shrink_tail.h"   // TAIL配列の圧縮を行うクラスが定義されたソースファイル(後述)

using namespace std;

// これから作成するtrieのノード数(= bit-vectorのビット数)をカウントする
// ※ ノード数を事前に取得しておく必要は必ずしもないが、この方がbit-vectorの構築は簡単になる
unsigned get_node_count(const char* filepath) {
  CharStreamVector csv(filepath);
  
  unsigned node_count=0;
  deque<Range> que;

  // trieをレベル順に探索して、ノード数をカウントする
  // ※ 深さ優先探索でも良い
  que.push_back(Range(0,csv.size()));
  while(!que.empty()) {
    node_count++;

    Range r = que.front();
    que.pop_front();

    // 追加部分:
    // TAIL配列に格納する末尾文字列は、ノードとしてカウントしない
    if((r.end-r.beg)==1)  // ← これ以後の子孫には分岐がない
      continue;
    
    if(csv[r.beg].eos())
      r.beg++;

    while(r.beg < r.end) {
      unsigned end = end_of_same_node(csv, r);
      que.push_back(Range(r.beg, end));
      r.beg = end;
    }
  }
  return node_count;
}

// trie(TAIL配列版)を作成する
void build_louds_trie_tail(CharStreamVector& csv, unsigned node_count, const char* filepath) {
  // bit-vector準備
  BitVectorBuilder r0(node_count);      // 子の有無判定用のbit-vector
  BitVectorBuilder r1(node_count);      // 兄弟の有無判定用のbit-vector
  BitVectorBuilder id(node_count);      // キーのIDを保持するためのbit-vector
  BitVectorBuilder tail_id(node_count); // TAIL配列用のID。tind配列の添字として使われる (追加部分)
  r1.add(1);  // スーパールート用の処理

  string vl;  // ノードの文字を保持する配列
  vl.reserve(node_count);
  vl += ' ';

  // 追加部分:
  string tail;           // TAIL配列
  vector<unsigned> tind; // TAIL配列中の末尾文字列の開始位置を保持する配列
  tind.push_back(0);     // 0番目は使われないので、0を挿入しておく

  // レベル順探索: get_node_count関数やlevel_order関数のそれと同様
  deque<Range> que;
  que.push_back(Range(0,csv.size()));
  while(!que.empty()) {
    Range r = que.front();
    que.pop_front();

    // 追加部分:
    // ユニークな末尾文字列に対する処理
    if((r.end-r.beg) == 1) {
      id.add(1);              // ID付与
      r0.add(0);              // 子はいないとしておく (どっちでも良い?)
      if(csv[r.beg].eos()) {  
        tail_id.add(0);       // ヌル文字(= キー終端)の場合は特別扱いして、TAIL配列に追加しない
      } else {
        tail_id.add(1);               // TAIL配列用のID付与 (ヌル文字を特別扱いしているので、通常のIDと分ける必要がある)
        tind.push_back(tail.size());  // tind配列に、このキーの末尾文字列のTAIL配列内での開始位置を格納する
        tail += csv[r.beg].rest();    // TAIL配列に末尾文字列追加
        tail += '\0';                 // ヌル終端
      }
      continue;
    }

    // id: キーの末尾に達した場合
    if(csv[r.beg].eos()) {
      r.beg++;
      id.add(1);  // IDを付与
    } else {
      id.add(0);  // 末尾ではないならIDは付けない
    }

    // r0: 子ノードがいるなら1bit、いないなら0bit
    r0.add(r.beg!=r.end);
  
    // 追加部分:
    tail_id.add(0);

    unsigned child_count=0;
    for(; r.beg < r.end; child_count++) {
      // vl: ノードに対応する文字を取得
      vl += csv[r.beg].peek();
      
      unsigned end = end_of_same_node(csv, r);
      que.push_back(Range(r.beg, end));
      r.beg = end;
    }
    
    // r1: 兄弟ノード設定
    if(child_count != 0) {
      for(unsigned i=1; i < child_count; i++)
        r1.add(0);  // 左(?)に兄弟がいるノードには0ビットを設定
      r1.add(1);    // 1ビットで終端
    }
  }

  // ファイルに出力
  FILE* f;
  if((f=fopen(filepath,"wb"))==NULL)
    return;

  r0.write(f);
  r1.write(f);
  id.write(f);
  tail_id.write(f);  // 追加部分
  
  // assert(node_count == vl.size());
  fwrite(&node_count, sizeof(unsigned), 1, f);   // ノード数を出力 
  fwrite(vl.data(), sizeof(char), vl.size(), f); // 各ノードに対応する文字を出力

  // 追加部分: 
  ShrinkTail(tail, tind).shrink();      // TAIL配列の圧縮(包含関係にある末尾文字列の併合)を行う

  unsigned len=tind.size();
  fwrite(&len, sizeof(unsigned), 1, f);
  fwrite(tind.data(), sizeof(unsigned), len, f); // XXX: vector.data()はgcc拡張
  
  len = tail.size();
  fwrite(&len, sizeof(unsigned), 1, f);
  fwrite(tail.data(), sizeof(char), len, f);

  fclose(f);
}

int main(int argc, char** argv) {
  CharStreamVector csv(argv[1]);
  build_louds_trie_tail(csv, get_node_count(argv[1]), argv[2]);
  return 0;
}


TAIL配列の圧縮(包含関係にある末尾文字列の併合)を行っているクラスの定義。
これは、Doarで使用しているものをほぼそのまま持ってきている。
一応次の記事も参照: 『DoubleArray(3-1): TAIL配列圧縮』

/***
 * ファイル名: shrink_tail.h
 */
#ifndef SHRINK_TAIL_H
#define SHRINK_TAIL_H

#include <vector>
#include <algorithm>

typedef std::vector<unsigned> TindList;
typedef std::string           Tail;

class ShrinkTail {
public:
  ShrinkTail(Tail& tail_, TindList& tind_) 
    : tail(tail_), tind(tind_) {}
  
  void shrink() {
    std::vector<ShrinkRecord> terminal_indices;
    terminal_indices.reserve(tind.size());
    
    for(unsigned i=0; i < tind.size(); i++)
      terminal_indices.push_back(ShrinkRecord(i,tail.data()+tind[i]));
    
    std::sort(terminal_indices.begin(), terminal_indices.end(), tail_gt);
    
    Tail new_tail;
    new_tail.reserve(tail.size()/2);
    new_tail += '\0';
    
    for(unsigned i=0; i < terminal_indices.size(); i++) {
      const ShrinkRecord& p = terminal_indices[i];
      
      unsigned tail_idx = new_tail.size();
      if(i>0 && can_share(terminal_indices[i-1], p)) {
        tail_idx -= p.tail_len+1; // +1 is necessary for last '\0' character
      } else {
        new_tail += p.tail;
        new_tail += '\0';
      }
      tind[p.tind_idx] = tail_idx;
    }
    tail = new_tail;
  }
  
private:
  struct ShrinkRecord {
    ShrinkRecord(unsigned i,const char* t) 
      : tind_idx(i),tail(t),tail_len(static_cast<int>(strlen(t))) {}
    unsigned      tind_idx;
    const char* tail;
    int tail_len;
  };
  
  // Is lft including rgt ?
  bool can_share(const ShrinkRecord& lft, const ShrinkRecord& rgt) const {
    const char* lp = lft.tail;
    const char* rp = rgt.tail;
    
    for(int li=lft.tail_len-1, ri=rgt.tail_len-1;; li--, ri--) {
      if(ri < 0)                return true;  // NOTE: ri must be checked before li.
      else if(li < 0)           return false;
      else if(lp[li] != rp[ri]) return false;
    }
  }
  
  // Is lft greater than rgt ?
  static bool tail_gt (const ShrinkRecord& lft, const ShrinkRecord& rgt) {
    const char* lp = lft.tail;
    const char* rp = rgt.tail;
    
    for(int li=lft.tail_len-1, ri=rgt.tail_len-1;; li--, ri--) {
      if(li < 0)               return false;
      else if(ri < 0)          return true;
      else if(lp[li] > rp[ri]) return true;
      else if(lp[li] < rp[ri]) return false;
    }
  }
  
private:
  Tail     &tail;
  TindList &tind;
};

#endif


次は読み込み/検索部分。

/***
 * ファイル名: louds_trie_tail.h
 *
 * LOUDSによるtrieの実装(TAIL配列版)
 * build_louds_trie_tail関数が保存したtrieデータを読み込む
 *
 * is_leaf/next_sibling/prev_sibling/first_child/last_child/parentメソッドに関しては、以下の記事を参照
 *  - http://d.hatena.ne.jp/sile/20100531/1275324255
 */
#ifndef LOUDS_TRIE_TAIL_H
#define LOUDS_TRIE_TAIL_H

#include "bit_vector.h"
#include <cstdio>

// 単なる文字列クラス
// mklouds-trie-trieコマンドが作成したファイルからの読み込みを簡単にするためだけに作成
class NodeChars {
public:
  NodeChars(FILE *f) {
    fread(&size, sizeof(unsigned), 1, f);
    chars = new char[size];
    fread(chars, sizeof(char), size, f);
  }
  ~NodeChars() {
    delete [] chars;
  }
  
  char operator[](unsigned index) const { return chars[index]; }
  
private:
  unsigned size;
  char*    chars;
};


// unsigned配列クラス
//  TAIL配列内の末尾文字列の開始位置を保持する
//  NodeCharsクラスと同様にファイルからの読み込みを簡単にするためだけに作成
class TailIndex {
public:
  TailIndex(FILE *f) {
    fread(&size, sizeof(unsigned), 1, f);
    indices = new unsigned[size];
    fread(indices, sizeof(unsigned), size, f);
  }

  ~TailIndex() {
    delete [] indices;
  }

  unsigned operator[](unsigned index) const { return indices[index]; }

private:
  unsigned  size;
  unsigned* indices;
};


// 単なる文字列クラス
//  NodeCharsクラスと同様にファイルからの読み込みを簡単にするためだけに作成
//  NodeCharsとまとめてしまってもいいかも
class Tail {
public:
  Tail(FILE *f) {
    fread(&size, sizeof(unsigned), 1, f);
    tail = new char[size];
    fread(tail, sizeof(char), size, f);
  }

  ~Tail() {
    delete [] tail;
  }

  const char* data() const { return tail; }
  
private:
  unsigned size;
  char* tail;
};

// trie(TAIL配列版)
class LoudsTrieTail {
public:
  static const unsigned ROOT_NODE=0;
  static const unsigned NOT_FOUND=0xFFFFFFFF;
  
  LoudsTrieTail(FILE *f) : r0(f), r1(f), id(f), ti(f), nc(f), tind(f), tail(f) {}

  // 引数のキーに紐付くIDを検索する
  // キーが存在しない場合は、LoudsTrie::NOT_FOUNDを返す 
  unsigned find(const char* key) const {
    unsigned node = first_child(ROOT_NODE);
    while(node != NOT_FOUND) {
      if(node_char(node) == *key) {
        key++;
        
        // キー終端: TAIL配列内にエントリがなく、かつノードにIDが紐付いているなら検索成功
        if(*key=='\0')
          return tail_id(node)==NOT_FOUND ? key_id(node) : NOT_FOUND;

        unsigned tid = tail_id(node);
        if(tid!=NOT_FOUND)  // TAIL配列内にエントリがある場合は、キーの残りの部分とstrcmpで比較
          return (strcmp(key, tail.data()+tind[tid])==0) ? key_id(node) : NOT_FOUND;
        
        node = first_child(node);
      } else {
        node = next_sibling(node);
      }
    }
    return NOT_FOUND;
  }

private:
  bool is_leaf(unsigned node)   const { return r0.is_0bit(node); }
  char node_char(unsigned node) const { return nc[node]; }
  unsigned next_sibling(unsigned node) const { return r1.is_0bit(node)   ? node+1 : NOT_FOUND; }
  unsigned prev_sibling(unsigned node) const { return r1.is_0bit(node-1) ? node-1 : NOT_FOUND; }
  unsigned first_child(unsigned node)const { return is_leaf(node) ? NOT_FOUND : r1.select1(r0.rank1(node))+1; }
  unsigned last_child(unsigned node) const { return is_leaf(node) ? NOT_FOUND : r1.select1(r0.rank1(node)+1); }
  unsigned parent(unsigned node)     const { return node==ROOT_NODE ? NOT_FOUND : r0.select1(r1.rank1(node-1)); }
  unsigned key_id(unsigned node)     const { return id.is_0bit(node)? NOT_FOUND : id.rank1(node); }

  // 追加
  unsigned tail_id(unsigned node)     const { return ti.is_0bit(node)? NOT_FOUND : ti.rank1(node); } 

private:
  const BitVector r0;
  const BitVector r1;
  const BitVector id;
  const BitVector ti;   // 追加: tind用の添字(ID)
  const NodeChars nc;
  const TailIndex tind; // 追加: TAIL配列中の末尾文字列の開始位置. tind[key_id(node)]
  const Tail      tail; // 追加: TAIL配列
};

#endif


メイン関数。
include文とクラス名が変わるくらい。

/***
 * ファイル名: louds-trie-tail.cc
 * コンパイル: g++ -o louds-trie-tail louds-trie-tail.cc
 * 使用方法: louds-trie-tail <mklouds-trie-tailが作成したバイナリデータ> 
 *             - 標準入力からキーを読み込み検索を行う
 *             - キーがtrie内に存在しない場合は、そのキーを標準出力に出力する
 */
#include <iostream>
#include <string>
using namespace std;

#include "louds_trie_tail.h"

int main(int argc, char** argv) {
  // trieの読み込み
  LoudsTrieTail lt(fopen(argv[1],"rb"));
  
  int cnt=0;
  string s;
  while(getline(cin,s)) {
    cnt++;
    if(cnt%100000==0)
      cerr << cnt << endl;
    
    // キーの検索
    if(lt.find(s.c_str())==lt.NOT_FOUND)
      cout << s << endl;
  }
  return 0;
}


実行例。
データは前回作成したものを用いる。

# 構築部コンパイル
$ g++ -O2 -o mklouds-trie-tail mklouds-trie-tail.cc

# 構築部実行
$ time ./mklouds-trie-tail /tmp/ipadic.word /tmp/trie-tail.dat
real	0m0.144s # 構築はTAIL配列無し版よりも少し遅い
user	0m0.132s
sys	0m0.012s

$ ls -lh /tmp/trie-trie.dat
-rw-r--r-- 1 user user 1.7M 2010-06-19 02:04 /tmp/trie-trie.dat  # TAIL配列なし版とサイズはあまり変わらず

# 検索部コンパイル
$ g++ -O2 -o louds-trie-tail louds-trie-tail.cc

# 実行
$ time ./louds-trie-tail /tmp/trie-tail.dat < /tmp/ipadic.word 
100000 # 以下三行は標準エラー出力に出力される進捗表示
200000
300000

real	0m0.476s  # 検索は少し速くなっている
user	0m0.476s
sys	0m0.000s

以上。
一応動くものが完成。

IPA辞書くらいのデータ数(or IPA辞書のようなキーセット)では、TAIL配列の有無でそれほど結果が変わることはなかった。

データ数を増やしての比較はまた今度。

*1:trieにTAIL配列を使うというアイデア自体は、DoubleArrayに関する右の論文で知った: Jun-ichi Aoe, katushi Morimoto and Takashi Satou,『An Efficient Implementation of Trie Structures』, Software Practice & Experience, Vol.22, No.9, pp.695-721, 1992

*2:ノードの文字用に8bit + 三つのbit-vector用に3〜6bit