LOUDS++(8): trie - 検索速度向上

試したいことが出てきたので、少し延長して八回目。

目的

今回の目的は六回目で作成したtrieをベースとし、検索速度を向上させること。

結果

最初に結果から載せる。
入力データや計測方法等は七回目のそれらに準拠。

データサイズ検索所要時間(秒)
louds-trie (五回目)74MB32.1372s
louds-trie-tail (六回目)52MB17.0782s
louds-trie-less-node52MB13.4044s
louds-trie-less-node-opt52MB11.6995s
doar94MB3.19229s
tx67MB51.556s
darts416MB3.35106s
louds-trie-less-nodeおよびlouds-trie-node-optが今回の追加分。
(今回の入力データでは)ベースとしたlouds-trie-tailに比べて、検索に要する時間が2/3程度に短縮されている。

louds-trie-less-node

上で書いた通りlouds-trie-less-nodeは、六回目に作成したlouds-trie-tailをベースとしている。
louds-trie-tailでは、各キーにユニークな末尾文字列を(ノードとして保持する代わりに)TAIL配列に格納していた。
TAIL配列を用いると、末端に続く無分岐の(= 兄弟がいない)ノードがtrie木上から除去されることになるので、総ノード数を減らすことが可能となる。
で、ノード数が減ると、コストが高いselect1(これはノードの子ノードを取得する際に必要)の呼び出しも減るので、結果検索速度が速くなる*1
louds-trie-less-nodeは、louds-trie-tailでのTAIL配列利用を一般化しており、末端に続くかどうかに関係なく全ての無分岐ノード(に対応する文字)をTAIL配列*2に格納するようになっている。
こうすることで、louds-trie-tailよりもさらにtrieに必要なノード数を抑えることが可能となり、結果的に検索速度も向上することが期待できる*3


参考までに、louds-trie/louds-trie-tail/louds-trie-less-nodeのそれぞれで、wiki.titleファイル*4に含まれるキーセットをtrieとして表現する場合に必要なノード数を載せておく。

$ wc -l wiki.title
5340378 wiki.title  # 約530万キー

# 必要なノード数
louds-trie:          45806866  # 約4580万ノード
louds-trie-tail:     11485872  # 約1150万ノード
louds-trie-less-node: 7356751  # 約 740万ノード


louds-trie-less-nodeの実装。
louds-trie-tailから変更がある箇所のみ。

/***
 * 構築用ソースコード
 *
 * ファイル名: mklouds-trie-less-node.cc
 * コンパイル: g++ -o mklouds-trie-less-node mklouds-trie-less-node.cc
 * 使用方法: mklouds-trie-less-node <キーセットを保持するソート済みファイル> <バイナリデータの出力先>
 */
#include "char_stream_vector.h"
#include "bit_vector_builder.h"
#include <cstdio>
#include <iostream>
#include <deque>
#include <string>
#include "shrink_tail.h"

using namespace std;

// これから作成するtrieのノード数(= bit-vectorのビット数)をカウントする
unsigned get_node_count(const char* filepath) {
  CharStreamVector csv(filepath);
  
  unsigned node_count=0;
  deque<Range> que;

  // trieをレベル順に探索して、ノード数をカウントする
  que.push_back(Range(0,csv.size()));
  for(; !que.empty(); node_count++) {
    Range r = que.front();
    que.pop_front();

    if(csv[r.beg].eos())
      if(++r.beg == r.end)
        continue;           // 末端に到達
    
    if(r.end-r.beg == 1)
      continue;             // 末端までに分岐がない
    
    // 分岐がない部分は飛ばす
    // これはlouds-trie-less-nodeに特有
    unsigned end;
    while(r.end == (end=end_of_same_node(csv, r)))
      if(csv[r.beg].eos())   // キー終端を検出しても終了
        break;

    // 子ノードをキューに追加
    for(;;) {
      que.push_back(Range(r.beg,end));
      r.beg = end;
      if(r.beg == r.end)
        break;
      end = end_of_same_node(csv,r);
    }
  }

  return node_count;
}

// trieを作成する
void build_louds_trie(CharStreamVector& csv, unsigned node_count, const char* filepath) {
  BitVectorBuilder r0(node_count);            // 子の有無判定用のbit-vector
  BitVectorBuilder r1(node_count);            // 兄弟の有無判定用のbit-vector
  BitVectorBuilder eos_key_num(node_count);   // キーの末尾に対応するノードの番号
  BitVectorBuilder tail_key_num(node_count);  // キーの末尾文字列に続くノードの番号
  BitVectorBuilder tail_id(node_count);       // TAIL配列用のID。tind配列の添字として使われる

  string vl;  // ノードの文字を保持する配列
  vl.reserve(node_count);

  Tail tail;
  TindList tind;

  // スーパールート用の処理
  r1.add(1);  
  vl += ' ';
  tind.push_back(0); 

  // レベル順探索: get_node_count関数のそれと同様
  deque<Range> que;  
  que.push_back(Range(0,csv.size()));
  while(!que.empty()) {
    Range r = que.front();
    que.pop_front();

    // キーの末尾に達したかどうか
    if(csv[r.beg].eos()) {
      eos_key_num.add(1);     // キー末尾
     if(++r.beg == r.end) {
        r0.add(0);            // 子はいない
        tail_id.add(0);       // TAIL配列は使わない
        tail_key_num.add(0);  // TAIL配列は使わない
        continue;
      }
    } else {
      eos_key_num.add(0);     // キー末尾ではない
    }
    
    // キーの末尾まで、無分岐ノードが続くかどうか
    if(r.end-r.beg == 1) {
      // 無分岐ノードが続く
      r0.add(0);            // 子はいない  # 残りはTAIL配列内に移してしまうので
      tail_id.add(1);       // TAIL配列を使う
      tail_key_num.add(1);  // このキーはTAIL配列内で終端する
      
      // TAIL配列に文字列追加
      tind.push_back(tail.size());  
      tail += csv[r.beg].rest();
      tail += '\0';
      continue;
    } else {
      r0.add(1);            // 子がいる
      tail_key_num.add(0);  // このキーはTAIL配列内では終端しない
    }

    string one_way_chars;  // 無分岐ノードに対応する文字列
    unsigned end;
    while(r.end == (end=end_of_same_node(csv,r))) { 
      if(csv[r.beg].eos()) 
        break;
      one_way_chars += csv[r.beg].prev();
    }

    if(one_way_chars.empty()) {
      // 無分岐ノード無し
      tail_id.add(0);      // TAIL配列は使わない
    } else {
      tail_id.add(1);      // TAIL配列を使う

      // TAIL配列に文字列追加
      tind.push_back(tail.size());
      tail += one_way_chars;
      tail += '\0';
    }

    for(;;) {
      vl += csv[r.beg].prev();
      que.push_back(Range(r.beg,end));
      
      r.beg = end;
      if(r.beg == r.end)
        break;
      r1.add(0);  // next_siblingありノード
      end = end_of_same_node(csv,r);
    }
    r1.add(1);    // next_siblingなしノード
  }
  
  // ファイルに出力
  FILE* f;
  if((f=fopen(filepath,"wb"))==NULL)
    return;

  r0.write(f);
  r1.write(f);
  eos_key_num.write(f);
  tail_key_num.write(f);
  tail_id.write(f);
  
  // assert(node_count == vl.size());
  fwrite(&node_count, sizeof(unsigned), 1, f);   // ノード数を出力 
  fwrite(vl.data(), sizeof(char), vl.size(), f); // 各ノードに対応する文字を出力

  // 
  ShrinkTail(tail, tind).shrink();

  unsigned len=tind.size();
  fwrite(&len, sizeof(unsigned), 1, f);
  fwrite(tind.data(), sizeof(unsigned), len, f); 
  
  len = tail.size();
  fwrite(&len, sizeof(unsigned), 1, f);
  fwrite(tail.data(), sizeof(char), len, f);

  fclose(f);
}

int main(int argc, char** argv) {
  CharStreamVector csv(argv[1]);
  build_louds_trie(csv, get_node_count(argv[1]), argv[2]);
  return 0;
}
/***
 * 検索用ソースコード
 * ※ 五回目、六回目で既に出てきたクラスは省略されている
 *
 * ファイル名: louds_trie_less_node.h
 */
#ifndef LOUDS_TRIE_LESS_NODE_H
#define LOUDS_TRIE_LESS_NODE_H

#include "bit_vector.h"
#include <cstdio>

class LoudsTrieLessNode {
public:
  static const unsigned ROOT_NODE=0;
  static const unsigned NOT_FOUND=0xFFFFFFFF;

  LoudsTrieTail(FILE *f) : r0(f), r1(f), eos_key_num(f), tail_key_num(f), tail_id(f), nc(f), tind(f), tail(f) {}
  
  // キー検索関数
  // 後で拡張するのために、内部的に実際の処理を行う関数を呼び出すようにしている
  unsigned find(const char* key) const {
    return find_liner(key, ROOT_NODE);
  }
  
private:
  // キー検索関数
  unsigned find_liner(const char* key, unsigned root_node) const {
    unsigned node = first_child(root_node);
    while(node != NOT_FOUND) {
      if(node_char(node) == *key) {
        key++;
        
        if(*key=='\0')
          return eos_key_num.is_0bit(node) ? NOT_FOUND : eos_node_id(node);

        // TAIL配列へ遷移する必要があるか?
        if(tail_id.is_0bit(node)==false) {
          const char* s = tail.data()+tind[tail_id.rank1(node)];
           
          // TAIL配列内で終端
          if(tail_key_num.is_0bit(node)==false)
            return strcmp(key,s)==0 ? tail_node_id(node) : NOT_FOUND;

          // TAIL配列から復帰する必要有り
          unsigned len=strlen(s);
          if(strncmp(key, s, len) != 0)
            return NOT_FOUND;
          key += len;
        }
        
        node = first_child(node);   // 子ノード取得
      } else {
        node = next_sibling(node);  // 兄弟ノード取得
      }
    }
    return NOT_FOUND;
  }

private:
  bool is_leaf(unsigned node)   const { return r0.is_0bit(node); }
  char node_char(unsigned node) const { return nc[node]; }
  unsigned next_sibling(unsigned node) const { return r1.is_0bit(node)   ? node+1 : NOT_FOUND; }
  unsigned prev_sibling(unsigned node) const { return r1.is_0bit(node-1) ? node-1 : NOT_FOUND; }
  unsigned first_child(unsigned node)const { return is_leaf(node) ? NOT_FOUND : r1.select1(r0.rank1(node))+1; }
  unsigned last_child(unsigned node) const { return is_leaf(node) ? NOT_FOUND : r1.select1(r0.rank1(node)+1); }
  unsigned parent(unsigned node)     const { return node==ROOT_NODE ? NOT_FOUND : r0.select1(r1.rank1(node-1)); }

  // キーのID取得方法には、少し変更あり
  unsigned eos_node_id(unsigned node)  const { return eos_key_num.rank1(node) + tail_key_num.rank1(node-1); }
  unsigned tail_node_id(unsigned node) const { return eos_key_num.rank1(node) + tail_key_num.rank1(node); }  

private:
  const BitVector r0;
  const BitVector r1;
  const BitVector eos_key_num;
  const BitVector tail_key_num;
  const BitVector tail_id;
  const NodeChars nc;
  const TailIndex tind;
  const Tail      tail;
};

#endif

louds-trie-less-node-opt

louds-trie-less-node-optは、基本louds-trie-less-nodeと同じだが、キー検索関数に以下の二つの変更が施されている。

  • first_childメソッド呼び出しのメモ化
    • 先頭からN番目までのノードに対するfirst_child呼び出し結果を、コンストラクタで保存しておく ※ Nは任意の定数
    • first_childメソッドは、select1とrank1を一回ずつ呼び出しているので、その分のコストを削除できる
  • 子ノード検索方法を、今までの線形検索から二分検索に(部分的に)変更
    • select1呼び出しの削減に比べて、劇的な速度向上はないが、最悪のケース*5での処理時間の悪化を抑えることが可能という意味で有用
    • "部分的に"とは、メモ化と併用することを想定しており、N番目以降のノードに対しては従来の線形検索を行うため
      • メモ化と併用しないと(first_childとlast_childを呼び出す必要があるため)かえって遅くなってしまう
      • また、(たぶん)大抵のtrieではレベルが深くなるに従って、個々のノードに属する子どもの数は少なくなるので、Nを十分大きくとれば二分検索よりも線形検索の方が高速になると予想される

変更分のソース。

/***
 * 検索用ソースコード
 * ※ louds-trie-less-nodeから変更がないメソッド等は省略されている
 *
 * ファイル名: louds_trie_less_node.h
 */
#ifndef LOUDS_TRIE_LESS_NODE_H
#define LOUDS_TRIE_LESS_NODE_H
class LoudsTrieLessNodeOpt {
public:
  static const unsigned MEMO_LIMIT=100000;  // メモ化対象のノード数  ※ 簡単のために今回は固定

private:
  unsigned memo_child_beg[MEMO_LIMIT];  // first_childメソッドの呼び出しを保存する配列
  unsigned memo_child_end[MEMO_LIMIT];  // last_child メソッドの呼び出しを保存する配列  ※ 実際には1byte配列で十分

public:
  LoudsTrieTail(FILE *f) : r0(f), r1(f), eos_key_num(f), tail_key_num(f), tail_id(f), nc(f), tind(f), tail(f) {
    // メモ化
    unsigned limit = (r1.length() < MEMO_LIMIT) ? r1.length() : MEMO_LIMIT;
    for(unsigned i=0; i < limit; i++) {
      memo_child_beg[i] = first_child(i);
      memo_child_end[i] = last_child(i);
      if(memo_child_end[i] != NOT_FOUND)
        memo_child_end[i]++;
    }
  }
  
  unsigned find(const char* key) const {
    // find_binaryを呼び出す
    return find_binary(key, ROOT_NODE);
  }
  
private:
  // キー検索関数 (二分検索版)
  unsigned find_binary(const char* key, unsigned root_node) const {
    unsigned node = bin_srch(*key, root_node);  // 子を検索
    while(node != NOT_FOUND) {

      // 以下から次の改行までは、find_linerと同様
      key++;
      if(*key=='\0')
        return eos_key_num.is_0bit(node) ? NOT_FOUND : eos_node_id(node);
      if(tail_id.is_0bit(node)==false) {
        const char* s = tail.data()+tind[tail_id.rank1(node)];
        if(tail_key_num.is_0bit(node)==false)
          return strcmp(key,s)==0 ? tail_node_id(node) : NOT_FOUND;      
        unsigned len=strlen(s);
        if(strncmp(key, s, len) != 0)
          return NOT_FOUND;
        key += len;
      }
      
      // メモ化対象範囲を越えたら、線形検索に移行
      if(node >= MEMO_LIMIT)
        return find_liner(key, node);

      node = bin_srch(*key, node);     // 子を検索
    }
    return NOT_FOUND;    
  }

  // 二分検索を使って、子ノードを検索する
  unsigned bin_srch(char ch, unsigned node) const {
    unsigned beg=memo_child_beg[node];
    unsigned end=memo_child_end[node];
    
    while(end-beg >= 8) {
      unsigned m = beg+(end-beg)/2;
      unsigned char uc1 = (unsigned char)node_char(m);
      unsigned char uc2 = (unsigned char)ch;
      
      if(uc1 == uc2)
        return m;
      if(uc1 < uc2)
        beg = m;
      else
        end = m;
    }
    
    for(unsigned i=beg; i < end; i++)
      if(ch==node_char(i))
        return i;
    
    return NOT_FOUND;    
  }
};
#endif

感想

これで試したいことはだいたい試し終えた。
後はbit-vectorの実装を変更/チューニングすることで、1〜2秒程度なら速くできるかもしれないが、もう劇的に速くなるといったようなことはないのではないかと思う。
現時点で(キーセットに大きく依存するが)DoubleArray(というかdoar)に比べて、サイズが半分で速度が1/4。
悪くはない。

ただ、これからdoarのサイズ縮小も試すつもりなので、もしそれが上手くいってしまったら、trieの実装においてLOUDSを使う意義はなくなってしまうな、とも思う。

ソースコード

現時点までに作成したソースコード一式は、もう少し整理してどこかにアップする予定。
と思っていたけど時間がないので、中止。余裕(or 機会)があったらいつか...(2010/07/12)

*1:例えば、極端な例として、"abcde"という文字列一つからなるtrieがあるとする。
TAIL配列を使えばキーの検索が一回の文字列比較で済む(select1呼び出しは0回)が、通常のtrieではselect1を5回呼び出す必要がある(成功検索の場合)

*2:この名前は適切ではない

*3:ただし、実際に検索速度が向上するかどうかはキーセットに依存すると思われる。例えば、無分岐ノードが存在しない場合はTAIL配列があろうがなかろうが関係ないし、無分岐ノードのフラグメント化(?)が激しい場合はTAIL配列とtrieノードを行き来するオーバヘッドの方が重くなる可能性もある

*4:六回目を参照

*5:Ex. 子ノードが密な木での不成功検索