LOUDS++(5): trie

四回目まででLOUDS(LOUDS++)の実装方法は分かったので、今回は応用編。
LOUDSを用いてtrieを実装する(というかtrieを実装したソースコードを載せる)

作るもの

実際に作成するのは以下の二つのコマンド。

  • mklouds-trie: ソート済みテキストファイルから、LOUDSを用いたtrieを構築し、バイナリデータとして保存する
  • louds-trie: 上で作成したtrieのバイナリデータを読み込み、trieのキーの検索を行う

注意点

mklouds-trie

trieの構築部分。
まずは、bit-vector構築用のクラスから。

/***
 * ファイル名: bit_vector_builder.h
 * 
 * ここで扱っているbit-vectorの詳細に関しては、LOUDS++の四回目を参照(実装4)
 *   - http://d.hatena.ne.jp/sile/20100613/1276445650
 */
#ifndef BIT_VECTOR_BUILDER_H
#define BIT_VECTOR_BUILDER_H

#include <cstdio>
#include <cassert>
#include <cstring>

const unsigned WORD_SIZE=32;
const unsigned BLOCK_SIZE=64;
const unsigned SELECT_INDEX_INTERVAL=64;

// 切り上げ除算関数
unsigned ceil(unsigned dividend, unsigned divisor) {
  return (dividend + (divisor - 1)) / divisor;
}

class BitVectorBuilder {
public:
  // bit-vectorのサイズ(ビット数)は、初期化時に指定しておく
  // ※ 初期化時にサイズを指定するのは、実装を簡略化するため。必須ではない。
  BitVectorBuilder(unsigned size)
    : size(size), block_count(ceil(size,BLOCK_SIZE)),
      blocks(new unsigned[block_count*2]),
      pre_1bit_counts(new unsigned[block_count+1]),
      select_indices(new unsigned[(size/SELECT_INDEX_INTERVAL+1)+1]),
      cur(0), acc_1bit_count(0) 
  {
    memset(blocks, 0, sizeof(unsigned)*(block_count*2));
    select_indices[0] = 0;
  }
  
  ~BitVectorBuilder() {
    delete [] blocks;
    delete [] pre_1bit_counts;
    delete [] select_indices;
  }
  
  // ビットをビット列の末尾に追加する
  // この関数は、ちょうどsize回だけ呼ばれる必要がある
  void add(bool b_1bit) {
    // ブロックごとに、それ以前にある1ビットの数を保存する
    if(cur%BLOCK_SIZE == 0)
      pre_1bit_counts[cur/BLOCK_SIZE] = acc_1bit_count;

    if(b_1bit) {
      // ブロックにビットをセット
      // ブロックは、単にビット列をunsigned型にエンコード(パック?)したもの
      blocks[cur/WORD_SIZE] |= (1 << cur%WORD_SIZE);

      acc_1bit_count++;      

      // 1bit数がSELECT_INDEX_INTERVALの倍数になった場合は、select_indicesに現在のブロックを保存する
      if(acc_1bit_count%SELECT_INDEX_INTERVAL == 0)
        select_indices[acc_1bit_count/SELECT_INDEX_INTERVAL] = cur/BLOCK_SIZE;
    }
    cur++;
  }

  // bit-vectorをファイルに書き出す
  void write(FILE* f) {
    finish();
    
    unsigned si_size = (acc_1bit_count/SELECT_INDEX_INTERVAL+1)+1;
 
    // 各フィールドの値を、そのまま保存しているだけ
    fwrite(&size,           sizeof(unsigned), 1, f);
    fwrite(&acc_1bit_count, sizeof(unsigned), 1, f);
    fwrite(blocks,          sizeof(unsigned), block_count*2, f);
    fwrite(pre_1bit_counts, sizeof(unsigned), block_count+1, f);
    fwrite(select_indices,  sizeof(unsigned), si_size, f);
  }

private:
  void finish() {
    // size分だけaddメソッドが呼ばれたかどうかのチェック
    assert(cur==size);

    // 番兵値
    pre_1bit_counts[block_count] = acc_1bit_count;
    select_indices[acc_1bit_count/SELECT_INDEX_INTERVAL+1] = block_count;
  }

private:
  const unsigned size;        // ビット数
  const unsigned block_count; // ブロック数 = ceil(ビット数, BLOCK_SIZE)
  unsigned* blocks;           // ブロックの配列(ビット列をunsigned型にパックしたもの)
  unsigned* pre_1bit_counts;  // 各ブロックの開始位置より前にある1ビットの数
  unsigned* select_indices;   // SELECT_INDEX_INTERVALの倍数番目ごとの1ビットが属するブロックを保持する配列
  
  unsigned cur;            // 現在ポイント中のビット(addメソッド呼び出しごとに一増える)
  unsigned acc_1bit_count; // cur以前のビット列に含まれる1ビットの数
};

#endif


次は、mklouds-trieコマンドのソース。
ここで出てくる関数の骨格は、基本的には『ソート済みファイルをトライ木に見立ててレベル順(幅優先)探索を行う』のlevel_order関数と同様。(また、CharStreamVectorクラスやRange構造体、end_of_same_node関数に関しても左のリンク先を参照)

/***
 * ファイル名: mklouds-trie.cc
 * コンパイル: g++ -o mklouds-trie mklouds-trie.cc
 * 使用方法: mklouds-trie <キーセットを保持するソート済みファイル> <バイナリデータの出力先>
 */
#include "char_stream_vector.h"
#include "bit_vector_builder.h"
#include <cstdio>
#include <iostream>
#include <deque>
#include <string>

using namespace std;

// これから作成するtrieのノード数(= bit-vectorのビット数)をカウントする
// ※ ノード数を事前に取得しておく必要は必ずしもないが、この方がbit-vectorの構築は簡単になる
unsigned get_node_count(const char* filepath) {
  CharStreamVector csv(filepath);
  
  unsigned node_count=0;
  deque<Range> que;

  // trieをレベル順に探索して、ノード数をカウントする
  // ※ 深さ優先探索でも良い
  que.push_back(Range(0,csv.size()));
  while(!que.empty()) {
    node_count++;

    Range r = que.front();
    que.pop_front();

    if(csv[r.beg].eos())
      r.beg++;

    while(r.beg < r.end) {
      unsigned end = end_of_same_node(csv, r);
      que.push_back(Range(r.beg, end));
      r.beg = end;
    }
  }
  return node_count;
}

// trieを作成する
void build_louds_trie(CharStreamVector& csv, unsigned node_count, const char* filepath) {
  // bit-vector準備
  BitVectorBuilder r0(node_count);  // 子の有無判定用のbit-vector
  BitVectorBuilder r1(node_count);  // 兄弟の有無判定用のbit-vector
  BitVectorBuilder id(node_count);  // キーのIDを保持するためのbit-vector
  r1.add(1);  // スーパールート用の処理

  string vl;  // ノードの文字を保持する配列
  vl.reserve(node_count);
  vl += ' ';

  // レベル順探索: get_node_count関数やlevel_order関数のそれと同様
  deque<Range> que;
  que.push_back(Range(0,csv.size()));
  while(!que.empty()) {
    Range r = que.front();
    que.pop_front();

    // id: キーの末尾に達した場合
    if(csv[r.beg].eos()) {
      r.beg++;
      id.add(1);  // IDを付与
    } else {
      id.add(0);  // 末尾ではないならIDは付けない
    }

    // r0: 子ノードがいるなら1bit、いないなら0bit
    r0.add(r.beg!=r.end);
    
    unsigned child_count=0;
    for(; r.beg < r.end; child_count++) {
      // vl: ノードに対応する文字を取得
      vl += csv[r.beg].peek();

      unsigned end = end_of_same_node(csv, r);
      que.push_back(Range(r.beg, end));
      r.beg = end;
    }
    
    // r1: 兄弟ノード設定
    if(child_count != 0) {
      for(unsigned i=1; i < child_count; i++)
        r1.add(0);  // 左(?)に兄弟がいるノードには0ビットを設定
      r1.add(1);    // 1ビットで終端
    }
  }

  // ファイルに出力
  FILE* f;
  if((f=fopen(filepath,"wb"))==NULL)
    return;

  r0.write(f);
  r1.write(f);
  id.write(f);
  
  // assert(node_count == vl.size());
  fwrite(&node_count, sizeof(unsigned), 1, f);   // ノード数を出力 
  fwrite(vl.data(), sizeof(char), vl.size(), f); // 各ノードに対応する文字を出力
  fclose(f);
}

int main(int argc, char** argv) {
  CharStreamVector csv(argv[1]);
  build_louds_trie(csv, get_node_count(argv[1]), argv[2]);
  return 0;
}

ソースコードは短くはないが、louds-trieを構築するために行っていることは、レベル順に訪れる各ノードで子の有無や兄弟の数によって1ビットもしくは0ビットをbit-vectorに追加しているだけ、と結構単純。


実行例。
テスト用のデータには、MeCabのサイトで配布されているIPA辞書を用いた。

# テストデータ作成: IPA辞書の各項目の形態素名だけを取り出し、ソート/ユニークする
$ nkf -w mecab-ipadic-2.7.0-20070801/*.csv | cut -d, -f1 | LC_ALL=C sort | LC_ALL=C uniq > /tmp/ipadic.word

$ wc -l /tmp/ipadic.word
325871 /tmp/ipadic.word

$ ls -lh /tmp/ipadic.word
-rw-r--r-- 1 user user 3.8M 2010-06-19 02:01 /tmp/ipadic.word

# コンパイル
$ g++ -O2 -o mklouds-trie mklouds-trie.cc

# 実行
$ time ./mklouds-trie /tmp/ipadic.word /tmp/trie.dat
real	0m0.099s
user	0m0.092s
sys	0m0.008s

$ ls -lh /tmp/trie.dat
-rw-r--r-- 1 user user 1.7M 2010-06-19 02:04 /tmp/trie.dat  # 1/2くらいのサイズに

louds-trie

保存したtrieの読み込み/検索部分。
ここでもまずはbit-vector用のクラス定義から。

/***
 * ファイル名: bit_vector.h
 *
 * BitVectorBuilderが構築/保存したbit-vectorの読み込み、及びそのbit-vector上でのselect1/rank1操作を行う
 *
 * ここで扱っているbit-vectorの詳細に関しては、LOUDS++の四回目を参照(実装4)
 *   - http://d.hatena.ne.jp/sile/20100613/1276445650
 */
#ifndef BIT_VECTOR_H
#define BIT_VECTOR_H

#include <cstdio>

const unsigned WORD_SIZE=32;
const unsigned BLOCK_SIZE=64;
const unsigned SELECT_INDEX_INTERVAL=64;

/**
 * 補助関数三つ
 */
unsigned ceil(unsigned dividend, unsigned divisor) {
  return (dividend + (divisor - 1)) / divisor;
}

// unsigned型の数値中の1ビットの数を求める
inline unsigned log_count(unsigned n) {
  n = ((n & 0xAAAAAAAA) >> 1) + (n & 0x55555555);
  n = ((n & 0xCCCCCCCC) >> 2) + (n & 0x33333333);
  n = ((n >> 4) + n) & 0x0F0F0F0F;
  n = (n>>8) + n;
  return ((n>>16) + n) & 0x3F;
}

// unsigned型の数値の中で、最も大きい1ビットの位置を求める
// 引数nは0ではないことが前提
inline unsigned int_length(unsigned n) {
  // SBCLのinteger-length関数が、BSR命令を使っていたので、それを真似てみた
  // ただし、common lispのinteger-length関数は、1ビットの位置+1を返すので、この関数の返り値とは若干異なる
  asm("movl %0, %%eax;"
      "bsr  %%eax, %0;"
      : "=r"(n) : "0"(n) : "%eax");
  return n;
}

/**
 * bit-vectorクラス
 */
class BitVector {
public:
  BitVector(FILE* f) {
    // BitVectorBuilderが保存したデータを読み込む
    fread(&size,             sizeof(unsigned), 1, f);
    fread(&total_1bit_count, sizeof(unsigned), 1, f);
    unsigned block_count = ceil(size, BLOCK_SIZE);
    unsigned si_size = (total_1bit_count/SELECT_INDEX_INTERVAL+1)+1;

    blocks =         new unsigned[block_count*2];
    pre_1bit_counts= new unsigned[block_count+1];
    select_indices = new unsigned[si_size];
    fread(blocks,         sizeof(unsigned), block_count*2, f);
    fread(pre_1bit_counts,sizeof(unsigned), block_count+1, f);
    fread(select_indices, sizeof(unsigned), si_size, f);
  }

  ~BitVector() {
    delete [] blocks;
    delete [] pre_1bit_counts;
    delete [] select_indices;
  }

  // select1
  unsigned select1(unsigned nth) const {
    // nth番目の1ビットが属するブロックを二分探索で探す
    unsigned block_beg = select_indices[nth/64];
    unsigned block_end = select_indices[nth/64+1]+1;
    for(;;) {
      unsigned block_num = block_beg+(block_end-block_beg)/2;
      unsigned start     = pre_1bit_counts[block_num];
      unsigned end       = pre_1bit_counts[block_num+1]+1;
      
      if(start < nth && nth < end) 
        return block_num*BLOCK_SIZE + block_select1(nth-start, block_num);
      if(nth <= start)
        block_end = block_num;
      else
        block_beg = block_num;
    }
  }

  // rank1
  unsigned rank1(unsigned index) const {
    unsigned block_num = index/BLOCK_SIZE;
    unsigned offset    = index%BLOCK_SIZE;
    return pre_1bit_counts[block_num] + block_rank1(offset, block_num);
  }
  
  // 位置indexのビットが0かどうか
  bool is_0bit(unsigned index) const {
    return (blocks[index/WORD_SIZE]&(1<<(index%WORD_SIZE)))==0;
  }

private:
  unsigned block_low (unsigned block_num) const { return blocks[block_num*2]; }
  unsigned block_high(unsigned block_num) const { return blocks[block_num*2+1]; }
  
  // ブロックに対するselect1
  unsigned block_select1(unsigned nth, unsigned block_num) const {
    unsigned block = block_low(block_num);
    unsigned count = log_count(block);
    if(nth == count)
      return int_length(block);

    unsigned base=0;
    if(nth > count) { // nth番目の1ビットが、block_highにある場合
      nth -= count;
      base = 32;
      block= block_high(block_num);

      count = log_count(block);
      if(nth == count)
	return base + int_length(block);
    }

    // nth番目の1ビットの位置を二分探索で探す
    unsigned beg = 0;
    unsigned end = 32;
    for(;;) {
      unsigned m = beg+(end-beg)/2;
      unsigned b = block & ((1<<m)-1);
      unsigned c = log_count(b);

      if(nth == c)
	return base + int_length(b);
      if(nth < c)
	end = m;
      else 
	beg = m;
    }
  }

  // ブロックに対するrank1
  unsigned block_rank1(unsigned offset, unsigned block_num) const {
    if(offset < 32)
      return log_count(block_low(block_num) & ((2 << offset)-1));

    return log_count(block_low (block_num)) +
      log_count(block_high(block_num) & ((2 << (offset-32))-1));
  }

private:
  unsigned size;
  unsigned total_1bit_count;
  unsigned *blocks;
  unsigned *pre_1bit_counts;
  unsigned *select_indices;
};

#endif


上のBitVectorクラスを用いたtrieの実装。

/***
 * ファイル名: louds_trie.h
 *
 * LOUDSによるtrieの実装
 * build_louds_trie関数が保存したtrieデータを読み込む
 *
 * is_leaf/next_sibling/prev_sibling/first_child/last_child/parentメソッドに関しては、以下の記事を参照
 *  - http://d.hatena.ne.jp/sile/20100531/1275324255
 */
#ifndef LOUDS_TRIE_H
#define LOUDS_TRIE_H

#include "bit_vector.h"
#include <cstdio>

// 単なる文字列クラス
// mklouds-trieコマンドが作成したファイルからの読み込みを簡単にするためだけに作成
class NodeChars {
public:
  NodeChars(FILE *f) {
    fread(&size, sizeof(unsigned), 1, f);
    chars = new char[size];
    fread(chars, sizeof(char), size, f);
  }
  ~NodeChars() {
    delete [] chars;
  }
  
  char operator[](unsigned index) const { return chars[index]; }
  
private:
  unsigned size;
  char*    chars;
};

// trie
class LoudsTrie {
public:
  static const unsigned ROOT_NODE=0;
  static const unsigned NOT_FOUND=0xFFFFFFFF;
  
  LoudsTrie(FILE *f) : r0(f), r1(f), id(f), nc(f) {}
  
  // 引数のキーに紐付くIDを検索する
  // キーが存在しない場合は、LoudsTrie::NOT_FOUNDを返す
  unsigned find(const char* key) const {
    unsigned node = first_child(ROOT_NODE);
    while(node != NOT_FOUND) {
      if(node_char(node) == *key) {
        key++;
        if(*key == '\0') 
          return key_id(node);
        
        node = first_child(node);
      } else {
        node = next_sibling(node);
      }
    }
    return NOT_FOUND;
  }

private:
  bool is_leaf(unsigned node)   const { return r0.is_0bit(node); }
  char node_char(unsigned node) const { return nc[node]; }
  unsigned next_sibling(unsigned node) const { return r1.is_0bit(node)   ? node+1 : NOT_FOUND; }
  unsigned prev_sibling(unsigned node) const { return r1.is_0bit(node-1) ? node-1 : NOT_FOUND; }
  unsigned first_child(unsigned node)const { return is_leaf(node) ? NOT_FOUND : r1.select1(r0.rank1(node))+1; }
  unsigned last_child(unsigned node) const { return is_leaf(node) ? NOT_FOUND : r1.select1(r0.rank1(node)+1); }
  unsigned parent(unsigned node)     const { return node==ROOT_NODE ? NOT_FOUND : r0.select1(r1.rank1(node-1)); }
  unsigned key_id(unsigned node)     const { return id.is_0bit(node)? NOT_FOUND : id.rank1(node); }

private:
  const BitVector r0;
  const BitVector r1;
  const BitVector id;
  const NodeChars nc;
};

#endif


メイン関数。
短い。

/***
 * ファイル名: louds-trie.cc
 * コンパイル: g++ -o louds-trie louds-trie.cc
 * 使用方法: louds-trie <mklouds-trieが作成したバイナリデータ> 
 *             - 標準入力からキーを読み込み検索を行う
 *             - キーがtrie内に存在しない場合は、そのキーを標準出力に出力する
 */
#include <iostream>
#include <string>
using namespace std;

#include "louds_trie.h"

int main(int argc, char** argv) {
  // trieの読み込み
  LoudsTrie lt(fopen(argv[1],"rb"));
  
  int cnt=0;
  string s;
  while(getline(cin,s)) {
    cnt++;
    if(cnt%100000==0)
      cerr << cnt << endl;
    
    // キーの検索
    if(lt.find(s.c_str())==lt.NOT_FOUND)
      cout << s << endl;
  }
  return 0;
}


実行例。

# コンパイル
$ g++ -O2 -o louds-trie louds-trie.cc

# 実行
$ time ./louds-trie /tmp/trie.dat < /tmp/ipadic.word 
100000  # 以下三行は標準エラー出力に出力される進捗表示
200000
300000

real	0m0.509s
user	0m0.504s
sys	0m0.008s

$ echo aiueo | ./louds-trie  /tmp/trie.dat 2> /dev/null
aiuet  # trie内に存在しないので標準出力に出力される


以上。
今回作成したものを使った性能評価的なものはまた今度。