ソート済みのリストに対する破壊的マージソートの改良

以前に載せたマージソート(をベースとしたもの)をSBCL(1.0.58)にコミットしてくれたPaul Khuongさんが、こんな記事を書いていて、なるほどなー、と思ったので、表題に関係する部分を参考にさせて貰って変更前後での比較を行ったメモ。

オリジナルのマージソート

まず、SBCL(1.0.58)のリストに対する破壊的マージソートの実装*1:

;; 二つのソート済みリストのマージ関数
(declaim (inline merge-lists*))
(defun merge-lists* (head list1 list2 test key &aux (tail head))
  (declare (type cons head list1 list2)
           (type function test key)
           (optimize speed))
  (macrolet ((merge-one (l1 l2)
               `(progn
                  (setf (cdr tail) ,l1
                        tail       ,l1)
                  (let ((rest (cdr ,l1)))
                    (cond (rest
                           (setf ,l1 rest))
                          (t
                           (setf (cdr ,l1) ,l2)
                           (return (cdr head))))))))
    (loop
     (if (funcall test (funcall key (car list2))  ; this way, equivalent
                       (funcall key (car list1))) ; values are first popped
         (merge-one list2 list1)                  ; from list1
         (merge-one list1 list2)))))

;; 実行
(merge-lists* '(:head) '(1 3 5) '(2 4 6) #'< #'identity))
=> (1 2 3 4 5 6)
;; リストのマージソート関数
(declaim (inline stable-sort-list))
(defun stable-sort-list (list test key &aux (head (cons :head list)))
  (declare (type list list)
           (type function test key)
           (dynamic-extent head))
  (labels ((recur (list size)
             (declare (optimize speed)
                      (type cons list)
                      (type (and fixnum unsigned-byte) size))
             (if (= 1 size)
                 (values list (shiftf (cdr list) nil))
                 (let ((half (ash size -1)))
                   (multiple-value-bind (list1 rest)
                       (recur list half)
                     (multiple-value-bind (list2 rest)
                         (recur rest (- size half))
                       (values (merge-lists* head list1 list2 test key)
                               rest)))))))
    (when list
      (values (recur list (length list))))))

;; 実行
(stable-sort-list '(8 73 2 40 0 3) #'< #'identity)
=> (0 2 3 8 40 73)

何種類かデータを用意して実行時間を計測:

;;; 計測用データ
;; 1] 400万要素のソート済みリスト
(defparameter *sorted-list* (loop FOR i FROM 0 BELOW 4000000 COLLECT i))

;; 2] 400万要素の逆順ソート済みリスト
(defparameter *reverse-sorted-list* (reverse *sorted-list*))

;; 3] 400万要素のほぼソート済みリスト1  ※ 千要素に一つがランダムな値
(defparameter *nearly-sorted-list1* (loop FOR i FROM 0 BELOW 4000000
                                         COLLECT (if (zerop (random 1000))
                                                     (random 4000000)
                                                   i)))

;; 4] 400万要素のほぼソート済みリスト2  ※ 複数のソート済みリストが連結
(defparameter *nearly-sorted-list2* (loop REPEAT 4 APPEND (loop FOR i FROM 0 BELOW 1000000 COLLECT i)))

;; 5] 400万要素のランダムなリスト
(defparameter *random-list* (loop REPEAT 4000000 COLLECT (random most-positive-fixnum)))


;;; 計測用マクロ
(defmacro sort-time (sort-fn-name list)
  `(let ((list~ (copy-list ,list)))
     (declare (optimize (speed 3) (safety 0)))
     (time (progn (,sort-fn-name list~ #'< #'identity)
                  t))))


;;; 計測
;; 1] ソート済みリスト
(sort-time stable-sort-list *sorted-list*)
Evaluation took:
  0.254 seconds of real time  ; 0.254秒
  0.252017 seconds of total run time (0.248016 user, 0.004001 system)
  99.21% CPU
  508,247,464 processor cycles
  0 bytes consed
=> T

;; 2] 逆順ソート済みリスト
(sort-time stable-sort-list *reverse-sorted-list*)
Evaluation took:
  0.235 seconds of real time  ; 0.235秒
  0.232015 seconds of total run time (0.232015 user, 0.000000 system)
  98.72% CPU
  468,869,834 processor cycles
  0 bytes consed
=> T

;; 3] ほぼソート済みリスト1  ※ 千要素に一つがランダムな値
(sort-time stable-sort-list *nearly-sorted-list1*)
Evaluation took:
  0.348 seconds of real time  ; 0.348秒
  0.348023 seconds of total run time (0.344022 user, 0.004001 system)
  100.00% CPU
  694,968,622 processor cycles
  0 bytes consed
=> T

;; 4] ほぼソート済みリスト2  ※ 複数のソート済みリストが連結
(sort-time stable-sort-list *nearly-sorted-list2*)
Evaluation took:
  0.271 seconds of real time  ; 0.271秒
  0.272017 seconds of total run time (0.272017 user, 0.000000 system)
  100.37% CPU
  538,952,732 processor cycles
  0 bytes consed
=> T

;; 5] ランダムリスト
(sort-time stable-sort-list *random-list*)
Evaluation took:
  2.171 seconds of real time  ; 2.171秒
  2.168135 seconds of total run time (2.160135 user, 0.008000 system)
  99.86% CPU
  4,332,215,938 processor cycles
  0 bytes consed
=> T

ソート済みのリストに対する改良を加えたマージソート

変更後のマージソート関数: ※ 変更内容はコメントを参照

;; 改良版マージソート関数
;; - fast-merge-lists*関数が追加されたこと以外は、もともとの関数とほとんど同様
;; - fast-merge-lists*関数は要素の範囲が重複しない二つのリストをO(1)でマージ可能
(declaim (inline stable-sort-list2))
(defun stable-sort-list2 (list test key &aux (head (cons :head list)))
  (declare (type list list)
           (type function test key)
           (dynamic-extent head))
        
           ;; マージ対象の二つのリスト内の片方が、もう片方に完全に先行している場合は、
           ;; 各要素の比較などは省略して、末尾のcdrの更新のみを行う。
  (labels ((fast-merge-lists* (try-fast-merge? list1 tail1 list2 tail2 rest)
             (when try-fast-merge?
                      ;; list1がlist2に完全に先行: (list1 .. tail1) <= (list2 .. tail2)
               (cond ((not (funcall test (funcall key (car list2))
                                         (funcall key (car tail1))))
                      (setf (cdr tail1) list2)
                      (return-from fast-merge-lists* (values list1 tail2 rest)))

                      ;; list2がlist1に完全に先行: (list2 .. tail2) < (list1 .. tail1)
                     ((funcall test (funcall key (car tail2))
                                    (funcall key (car list1)))
                      (setf (cdr tail2) list1)
                      (return-from fast-merge-lists* (values list2 tail1 rest)))))
             
             ;; その他: 通常のマージ
             (values (merge-lists* head list1 list2 test key)
                     (if (null (cdr tail1))
                         tail1
                       tail2)
                     rest))
                  
            ;; トップダウンマージリスト関数: リストの末尾を管理するようになったのとfast-merge-lists*関数を使うようになったこと以外は変更なし            
            (recur (list size)
             (declare (optimize speed)
                      (type cons list)
                      (type (and fixnum unsigned-byte) size))
             (if (= 1 size)
                 (values list list (shiftf (cdr list) nil))
                 (let ((half (ash size -1)))
                   (multiple-value-bind (list1 tail1 rest)
                       (recur list half)
                     (multiple-value-bind (list2 tail2 rest)
                         (recur rest (- size half))
                       (fast-merge-lists* (>= size 8)  ; オーバヘッドを少なくするために、一定サイズ以上のリストに対してのみ適用を試みる
                                          list1 tail1 list2 tail2 rest)))))))
    (when list
      (values (recur list (length list))))))

;; 実行
(stable-sort-list2 '(8 73 2 40 0 3) #'< #'identity)
=> (0 2 3 8 40 73)

処理時間計測:

;; 1] ソート済みリスト
(sort-time stable-sort-list2 *sorted-list*)
Evaluation took:
  0.086 seconds of real time  ; 0.086秒  (変更前: 0.254秒)
  0.088005 seconds of total run time (0.088005 user, 0.000000 system)
  102.33% CPU
  171,845,432 processor cycles
  0 bytes consed
=> T

;; 2] 逆順ソート済みリスト
(sort-time stable-sort-list2 *reverse-sorted-list*)
Evaluation took:
  0.087 seconds of real time  ; 0.0.87秒  (変更前: 0.235秒)
  0.088006 seconds of total run time (0.088006 user, 0.000000 system)
  101.15% CPU
  173,196,084 processor cycles
  0 bytes consed
=> T

;; 3] ほぼソート済みリスト1  ※ 千要素に一つがランダムな値
(sort-time stable-sort-list2 *nearly-sorted-list1*)
Evaluation took:
  0.293 seconds of real time  ; 0.293秒  (変更前: 0.348秒)
  0.292019 seconds of total run time (0.292019 user, 0.000000 system)
  99.66% CPU
  585,393,530 processor cycles
  0 bytes consed
=> T

;; 4] ほぼソート済みリスト2  ※ 複数のソート済みリストが連結
(sort-time stable-sort-list2 *nearly-sorted-list2*)
Evaluation took:
  0.122 seconds of real time  ; 0.122秒  (変更前: 0.271秒)
  0.120007 seconds of total run time (0.116007 user, 0.004000 system)
  98.36% CPU
  242,403,024 processor cycles
  0 bytes consed
=> T

;; 5] ランダムリスト
(sort-time stable-sort-list2 *random-list*)
Evaluation took:
  2.193 seconds of real time  ; 2.193秒  (変更前: 2.171秒)
  2.192138 seconds of total run time (2.164136 user, 0.028002 system)
  99.95% CPU
  4,376,336,316 processor cycles
  0 bytes consed
=> T

完全にランダムなリストに対するソートは心なしか改良版の方が(ごく若干)遅くなっているように思うが、入力リストにソート済みの部分が多ければ多いほど、確実に改良版の方が速くなっている。
確かに、二つのリストをマージする場合、それぞれの領域が独立しているなら、片方の先頭要素ともう片方の末尾要素を比較するだけで、リスト全体を完全に順序づけ可能なんだけど、自分が実装方法を考えている時には、そのことに思い至らなかった。
なるほどなー。

*1:sbcl-1.0.58/src/code/sort.lisp より引用

エラトステネスの篩

loop*1を使って、エラトステネスの篩を実装してみたメモ。
以下、処理系にはSBCLのver1.0.54(x86-64bit)を使用。

;; 引数nまでの範囲の素数のシーケンス(ジェネレータ)を作成する
(declaim (inline make-prime-sequence))
(defun make-prime-sequence (n)
  (let ((arr (make-array (1+ n) :element-type 'bit :initial-element 1)))
    (flet ((prime? (i) (= (bit arr i) 1))       
           (not-prime! (i) (setf (bit arr i) 0))) 
      (declare (inline prime? not-prime!))

      (loop:each (lambda (i)
                   (when (prime? i)
                     (loop:each #'not-prime! (loop:from (* i 2) :to n :by i))))
                 (loop:from 2 :to (floor (sqrt n))))
    
      (loop:filter #'prime? (loop:from 2 :to n)))))

;;; 実行例
;; 100以下の素数
(loop:collect (make-prime-sequence 100))
=> (2 3 5 7 11 13 17 19 23 29 31 37 41 43 47 53 59 61 67 71 73 79 83 89 97)

;; 1001から1010番目の素数
(loop:collect (loop:take 10 (loop:drop 1000 (make-prime-sequence 10000000))))
=> (7927 7933 7937 7949 7951 7963 7993 8009 8011 8017)

通常のループ(loopマクロ)を使った場合との速度比較。

;; 比較用に素数の合計値を求める関数を用意
(defun prime-sum1 (n)
  (declare (fixnum n)
           (optimize (speed 3) (safety 0) (debug 0)))
  (loop:sum #'identity (make-prime-sequence n)))

;; 一億以下の素数の合計値
(time (prime-sum1 100000000))
Evaluation took:
  1.675 seconds of real time  ; 1.675秒
  1.676105 seconds of total run time (1.676105 user, 0.000000 system)
  100.06% CPU
  3,342,591,038 processor cycles
  12,500,032 bytes consed
=> 279209790387276
;; loopマクロ版
(defun prime-sum2 (n)
  (declare (fixnum n)
           (optimize (speed 3) (safety 0) (debug 0)))
  (let ((arr (make-array (1+ n) :element-type 'bit :initial-element 1)))
    (flet ((prime? (i) (= (bit arr i) 1))
           (not-prime! (i) (setf (bit arr i) 0)))
      (declare (inline prime? not-prime!))

      (loop FOR i fixnum FROM 2 TO (floor (sqrt n))
            WHEN (prime? i)
        DO
        (loop FOR j fixnum FROM (* i 2) TO n BY i
          DO
          (not-prime! j)))

      (loop WITH sum OF-TYPE (unsigned-byte 64)
            FOR i fixnum FROM 2 TO n
            WHEN (prime? i)
        DO (incf sum i)
        FINALLY (return sum)))))

;; 一億以下の素数の合計値
(time (prime-sum2 100000000))
Evaluation took:
  1.476 seconds of real time  ; 1.476秒
  1.472092 seconds of total run time (1.468092 user, 0.004000 system)
  99.73% CPU
  2,944,592,020 processor cycles
  12,500,032 bytes consed
=> 279209790387276

ループ処理を関数型っぽく書いてみる(2)

前回の続き。
githubにあるloopの簡易版を載せておく。

基本的な考え方

基本的なJava等のIteratorと似た*1インタフェースを通してループ処理を実現している。
異なるのは全ての関数をinline展開可能にすることで、同等のループを非関数型的に書いた場合と同じくらいに、コンパイラが最適化を行ってくれることを期待していることくらい。
後は、SBCLの最適化の制限上、構造体等は使用せず、極力lambdaで全てを表現するようにしている。

実装

まず、loopパッケージ用のシーケンス生成関数。

;; 数値の範囲を表現するシーケンス
(declaim (inline from))
(defun from (start &key to (by 1))  ; toがnilなら無限シーケンス
  ;; 全体をlambdaで囲む。このlambdaの呼び出しがシーケンスの初期化処理に相当する。
  (lambda () 
    (let ((cur start))
      ;; 以下の三つの関数を呼び出し元に返す
      (values (lambda () (incf cur by))          ; 1] 値更新関数
              (lambda () (and to(> cur to)))     ; 2] 終端判定関数
              (lambda (fn) (funcall fn cur)))))) ; 3] ループの本体実行関数

;; リスト用
(declaim (inline for-list))
(defun for-list (list)
  (lambda ()
    (let ((head list)) ; 初期値
      (values (lambda () (setf head (cdr head)))         ; 1] 値更新関数
              (lambda () (endp head))                    ; 2] 終端判定関数
              (lambda (fn) (funcall fn (car head)))))))  ; 3] ループの本体実行関数

;; 実行
> (from 1 :to 10)
#<CLOSURE (LAMBDA () :IN FROM) {1007855D3B}>

> (funcall (from 1 :to 10))
#<CLOSURE (LAMBDA () :IN FROM) {10078C432B}>   ; 1] 値更新関数
#<CLOSURE (LAMBDA () :IN FROM) {10078C434B}>   ; 2] 終端判定関数
#<CLOSURE (LAMBDA (FN) :IN FROM) {10078C436B}> ; 3] ループの本体実行関数

> (setf (values next end? call-body) (funcall (from 1 :to 10)))
> (funcall call-body (lambda (x) (list :val x)))
=> (:val 1)

> (funcall next)
> (funcall call-body (lambda (x) (list :val x)))
=> (:val 2)

上の関数で生成されたシーケンスを走査する関数。

;; 一番基本となる走査関数
(declaim (inline each))
(defun each (fn seq)
  (multiple-value-bind (next-fn end-fn call-fn) (funcall seq)  ; シーケンス初期化
    (loop UNTIL (funcall end-fn)   ; 終端判定
          DO (funcall call-fn fn)  ; 本体実行
             (funcall next-fn))))  ; 値更新

;; 畳み込み関数  ※ reduceはclパッケージとそれと名前が衝突するので、ここではfoldにしている
(declaim (inline fold))
(defun fold (fn init seq)
  (let ((acc init))
    (each (lambda (x)
            (setf acc (funcall fn acc x)))
          seq)
    acc))

;; シーケンスを集めたリストを返す
(declaim (inline collect))
(defun collect (seq)
  (nreverse (fold (lambda (acc x) (cons x acc))
                    '()
                    seq)))

;; 実行
> (collect (from 1 :to 20 :by 3))
=> (1 4 7 10 13 16 19)

; 合計値計算
> (fold (lambda (acc x) (+ acc x))
        0
        (from 1 :to 20 :by 3))
=> 70

mapとかfilterとかシーケンスを加工/制御する関数。

;; map関数
(declaim (inline map-seq))
(defun map-seq (map-fn seq)
  ;; ソースとなるシーケンスの情報を取得し、それをラップして返す
  (multiple-value-bind (next-fn end-fn call-fn) (funcall seq)
    (lambda ()
      (values next-fn  ; 値更新関数と終端判定関数はそのまま
              end-fn
              (lambda (body-fn)
                ;; 本体呼び出し前に、マップ処理用関数を差し込む
                (funcall call-fn (lambda (val) (funcall body-fn (funcall map-fn val)))))))))

;; filter関数: (funcall pred-fn val)がnilとなる要素をスキップする
(declaim (inline filter))
(defun filter (pred-fn seq)
  ;; ソースとなるシーケンスの情報を取得し、それをラップして返す
  (multiple-value-bind (next-fn end-fn call-fn) (funcall seq)
    (lambda ()
      (values next-fn ; 値更新関数と終端判定関数はそのまま
              end-fn
              (lambda (body-fn)
                ;; 本体呼び出し前に、フィルター処理用関数を差し込む
                (funcall call-fn (lambda (val)
                                   (unless (funcall pred-fn val)
                                     (funcall body-fn val)))))))))

;; 実行
; 二乗する
> (collect (map-seq (lambda (x) (* x x)) (from 1 :to 20)))
=> (1 4 9 16 25 36 49 64 81 100 121 144 169 196 225 256 289 324 361 400)

; 奇数の値だけフィルタして二乗する
> (collect (map-seq (lambda (x) (* x x)) (filter #'oddp (from 1 :to 20))))
=> (1 9 25 49 81 121 169 225 289 361)

これらの関数群を組み合わせてループ処理を表現すると、そこそこ良い感じのコードを生成してくれる。

;; 上で定義した関数群を用いたsum関数
;; - startからendの範囲の奇数値を-10した合計値を返す
(defun sum1 (start end)
 (declare (fixnum start end)
          (optimize (speed 3) (safety 0))
          (sb-ext:unmuffle-conditions sb-ext:compiler-note))
 (fold (lambda (acc n)
         (the fixnum (+ acc n)))
       0
       (map-seq (lambda (x) (- x 10)) 
                (filter #'oddp (from start :to end)))))

;; loopマクロを使用したsum関数
(defun sum2 (start end)
 (declare (fixnum start end)
          (optimize (speed 3) (safety 0))
          (sb-ext:unmuffle-conditions sb-ext:compiler-note))
 (loop WITH total fixnum = 0
       FOR i FROM start TO end
       WHEN (oddp i)
   DO (let ((n (- i 10)))
        (declare (fixnum n))
        (incf total n))
   FINALLY (return total)))

;; 一億要素に対するループ
> (time (sum1 1 100000000))
Evaluation took:
  0.134 seconds of real time  ; 0.134秒
  0.136009 seconds of total run time (0.136009 user, 0.000000 system)
  101.49% CPU
  267,335,373 processor cycles
  0 bytes consed
=> 2499999500000000

> (time (sum2 1 100000000))
Evaluation took:
  0.131 seconds of real time  ; 0.131秒
  0.132008 seconds of total run time (0.132008 user, 0.000000 system)
  100.76% CPU
  261,630,697 processor cycles
  0 bytes consed
=> 2499999500000000

;; disassemble結果
> (disassemble #'sum1)
; disassembly for SUM1
; 07AA1C28:       31D2             XOR EDX, EDX               ; no-arg-parsing entry point
;       2A:       EB1B             JMP L2
;       2C:       90               NOP
;       2D:       90               NOP
;       2E:       90               NOP
;       2F:       90               NOP
;       30: L0:   488BC1           MOV RAX, RCX
;       33:       488D1C4500000000 LEA RBX, [RAX*2]
;       3B:       4883E302         AND RBX, 2
;       3F:       4885DB           TEST RBX, RBX
;       42:       750E             JNE L3
;       44: L1:   48FFC1           INC RCX
;       47: L2:   4839F9           CMP RCX, RDI
;       4A:       7EE4             JLE L0
;       4C:       488BE5           MOV RSP, RBP
;       4F:       F8               CLC
;       50:       5D               POP RBP
;       51:       C3               RET
;       52: L3:   4883E80A         SUB RAX, 10
;       56:       48D1FA           SAR RDX, 1
;       59:       4801C2           ADD RDX, RAX
;       5C:       48D1E2           SHL RDX, 1
;       5F:       EBE3             JMP L1

> (disassemble #'sum2)
; disassembly for SUM2
; 07EF0DB8:       31D2             XOR EDX, EDX               ; no-arg-parsing entry point
;       BA:       EB2A             JMP L2
;       BC:       90               NOP
;       BD:       90               NOP
;       BE:       90               NOP
;       BF:       90               NOP
;       C0: L0:   488D044D00000000 LEA RAX, [RCX*2]
;       C8:       4883E002         AND RAX, 2
;       CC:       4885C0           TEST RAX, RAX
;       CF:       7412             JEQ L1
;       D1:       488D044D00000000 LEA RAX, [RCX*2]
;       D9:       488BD8           MOV RBX, RAX
;       DC:       4883EB14         SUB RBX, 20
;       E0:       4801DA           ADD RDX, RBX
;       E3: L1:   48FFC1           INC RCX
;       E6: L2:   4839F9           CMP RCX, RDI
;       E9:       7ED5             JLE L0
;       EB:       488BE5           MOV RSP, RBP
;       EE:       F8               CLC
;       EF:       5D               POP R

最後は複数シーケンスをまとめるzip関数。
これを使うと表現力はだいぶ上がるけど、性能は若干劣化する。

;; 二つのシーケンスをまとめる
(declaim (inline zip))
(defun zip (loop1 loop2 &aux (undef (gensym)))
  (multiple-value-bind (next-fn1 end-fn1 call-fn1) (funcall loop1)
    (multiple-value-bind (next-fn2 end-fn2 call-fn2) (funcall loop2)
      (let ((memo1 undef)
            (memo2 undef))
        (lambda ()
          (values (lambda () ; 値更新
                    (when (eq memo1 undef) (funcall next-fn1))
                    (when (eq memo2 undef) (funcall next-fn2)))

                  (lambda () ; 終端判定
                    (or (funcall end-fn1) (funcall end-fn2)))

                  (lambda (body-fn)  ; 本体呼び出し
                    ;; それぞれのシーケンスの次の値を取得する
                    ;; (次の値がfilterでスキップされた場合は memoX はundefのままになる)
                    (when (eq memo1 undef)
                      (funcall call-fn1 (lambda (val) (setf memo1 val))))

                    (when (eq memo2 undef)
                      (funcall call-fn2 (lambda (val) (setf memo2 val))))
    
                    ;; 両方のシーケンスの値が取得できたら、本体を呼び出す
                    (when (not (or (eq memo1 undef)
                                   (eq memo2 undef)))
                      (funcall fn (list memo1 memo2))  ; XXX: listで二つの値をまとめるのはconsingが発生するので効率が悪い (そのためloopパッケージでは、多引数を受け取るmapやfilterを用意している)
                      (setf memo1 undef
                            memo2 undef)))))))))

;; 実行
> (collect
    (zip (filter (lambda (n) (and (oddp n)  (zerop (mod n 3)))) (from 1))           ; 奇数かつ三の倍数
         (filter (lambda (n) (and (evenp n) (zerop (mod n 5)))) (from 1 :to 100)))) ; 偶数かつ五の倍数
=> ((3 10) (9 20) (15 30) (21 40) (27 50) (33 60) (39 70) (45 80) (51 90) (57 100))

zipはもう少し上手く実装したいところだけど、それでも関数型っぽく書いても実用上十分な性能がでるループ処理が実現できそうなことが分かったので、結構満足している。

*1:似てないかも

Gomokuの形態素解析部をScalaで実装してみた

ここ数日はScalaのコップ本を読んでいて、何かまとまったプログラムをScalaで書いてみたくなったのでGomoku(Java形態素解析器。ver0.0.6)Scalaで実装してみた*1
github: scala-gomoku(ver0.0.1)

以下、使用例とJava版/Scala版の簡単な比較メモ。

使用例

$ scala -cp scala-gomoku-0.0.1.jar

// インタプリタ起動 & パッケージインポート
scala> import net.reduls.scala.gomoku._

// 分かち書き
scala> Tagger.wakati("Scalaはオブジェクト指向言語と関数型言語の特徴を統合したマルチパラダイムのプログラミング言語である。")
res0: List[String] = 
        List(Scala, は, オブジェクト, 指向, 言語, と, 関数, 型, 言語, の, 特徴, を, 統合, し, た, マルチパラダイム, の, プログラミング, 言語, で, ある, 。)

// 形態素解析
scala> Tagger.parse("Scalaはオブジェクト指向言語と関数型言語の特徴を統合したマルチパラダイムのプログラミング言語である。")
res1: List[net.reduls.scala.gomoku.Morpheme] =
         List(Morpheme(Scala,名詞,固有名詞,組織,*,*,*,0), Morpheme(は,助詞,係助詞,*,*,*,*,5), Morpheme(オブジェクト,名詞,一般,*,*,*,*,6), Morpheme(指向,名詞,サ変接続,*,*,*,*,12), Morpheme(言語,名詞,一般,*,*,*,*,14), 
              Morpheme(と,助詞,並立助詞,*,*,*,*,16), Morpheme(関数,名詞,一般,*,*,*,*,17), Morpheme(型,名詞,接尾,一般,*,*,*,19), Morpheme(言語,名詞,一般,*,*,*,*,20), Morpheme(の,助詞,連体化,*,*,*,*,22), Morpheme(特徴,名詞,一般,*,*,*,*,23), 
              Morpheme(を,助詞,格助詞,一般,*,*,*,25), Morpheme(統合,名詞,サ変接続,*,*,*,*,26), Morpheme(し,動詞,自立,*,*,サ変・スル,連用形,28), Morpheme(た,助動詞,*,*,*,特殊・タ,基本形,29), Morpheme(マルチパラダイム,名詞,一般,*,*,*,*,30), 
              Morpheme(の,助詞,連体化,*,*,*,*,38), Morpheme(プログラミング,名詞,サ変接続,*,*,*,*,39), Morpheme(言語,名詞,一般,*,*,*,*,46), Morpheme(で,助動詞,*,*,*,特殊・ダ,連用形,48), Morpheme(ある,助動詞,*,*,*,五段・ラ行アル,基本形,49), Morpheme(。,記号,句点,*,*,*,*,51))

// 名詞のみ取り出し
scala> for(m <- res1 if m.feature.startsWith("名詞")) yield m.surface
res2: List[String] = 
        List(Scala, オブジェクト, 指向, 言語, 関数, 型, 言語, 特徴, 統合, マルチパラダイム, プログラミング, 言語)

ソースコード行数

形態素解析部のソースコードの行数比較。

Java:

$ cd gomoku-0.0.6-src
$ wc -l `find . -name '*.java'`
  117 ./analyzer/src/net/reduls/gomoku/Tagger.java
   12 ./analyzer/src/net/reduls/gomoku/Morpheme.java
   23 ./analyzer/src/net/reduls/gomoku/util/ReadLine.java
   83 ./analyzer/src/net/reduls/gomoku/util/Misc.java
   38 ./analyzer/src/net/reduls/gomoku/bin/Gomoku.java
   32 ./analyzer/src/net/reduls/gomoku/dic/Unknown.java
   72 ./analyzer/src/net/reduls/gomoku/dic/Char.java
   23 ./analyzer/src/net/reduls/gomoku/dic/WordDic.java
   61 ./analyzer/src/net/reduls/gomoku/dic/SurfaceId.java
   43 ./analyzer/src/net/reduls/gomoku/dic/Morpheme.java
   26 ./analyzer/src/net/reduls/gomoku/dic/PartsOfSpeech.java
   23 ./analyzer/src/net/reduls/gomoku/dic/ViterbiNode.java
   26 ./analyzer/src/net/reduls/gomoku/dic/Matrix.java
  579 合計

Scala:

$ cd scala-gomoku-0.0.1-src
$ wc -l `find . -name '*.scala'`
   3 ./src/net/reduls/scala/gomoku/Morpheme.scala
  27 ./src/net/reduls/scala/gomoku/bin/Gomoku.scala
  15 ./src/net/reduls/scala/gomoku/dic/Matrix.scala
  13 ./src/net/reduls/scala/gomoku/dic/PartsOfSpeech.scala
  18 ./src/net/reduls/scala/gomoku/dic/Morpheme.scala
  22 ./src/net/reduls/scala/gomoku/dic/Char.scala
  32 ./src/net/reduls/scala/gomoku/dic/Util.scala
   9 ./src/net/reduls/scala/gomoku/dic/ViterbiNode.scala
  39 ./src/net/reduls/scala/gomoku/dic/SurfaceId.scala
  30 ./src/net/reduls/scala/gomoku/dic/Unknown.scala
  15 ./src/net/reduls/scala/gomoku/dic/WordDic.scala
  56 ./src/net/reduls/scala/gomoku/Tagger.scala
 279 合計

Scala版はJava版に対して、おおよそ半分程度の行数。

処理速度

以下のようなベンチマークスクリプトを書いて、両者の処理速度を比較してみた。

// ファイル名: Benchmark.scala

import scala.testing.Benchmark
import net.reduls.scala.gomoku.{Tagger=>ScalaTagger}
import net.reduls.gomoku.{Tagger=>JavaTagger}
import scala.io.Source

// ベンチマーク用データ: 使用したのは約17MBの日本語テキストデータ
object BenchmarkData {
  val lines = Source.fromFile("/path/to/testdata").getLines.toArray
}

// Scala用のベンチマークオブジェクト
object ScalaGomokuBenchmark extends Benchmark {
  // BenchmarkData.linesの各行を分かち書き
  override def run() { BenchmarkData.lines.foreach(ScalaTagger.wakati _) } 
}

// Scala用のベンチマークオブジェクト
object JavaGomokuBenchmark extends Benchmark {
  override def run() { BenchmarkData.lines.foreach(JavaTagger.wakati _) }
}

// ベンチマーク実行
println("[Data]")
println("  lines: " + BenchmarkData.lines.length)
println("")

val scalaRlt = ScalaGomokuBenchmark.runBenchmark(11).tail
println("[Scala]")
println("  result : " + scalaRlt.mkString(", "))
println("  average: " + (scalaRlt.sum / scalaRlt.length))
println("")

val javaRlt = JavaGomokuBenchmark.runBenchmark(11).tail
println("[Java]")
println("  result : " + javaRlt.mkString(", "))
println("  average: " + (javaRlt.sum / javaRlt.length))
println("")

実行結果:

# Scala: version 2.9.0.1 (OpenJDK 64-Bit Server VM, Java 1.6.0_23).
$ scala -cp scala-gomoku-0.0.1.jar:gomoku-0.0.6.jar Benchmark.scala
[Data]
  lines: 172088  # データの行数(約17万行)

[Scala]
  result : 4529, 4574, 4568, 4540, 4503, 4510, 4523, 4515, 4551, 4531  
  average: 4534  # 平均: 4.534秒

[Java]
  result : 3153, 3111, 3118, 3112, 3102, 3098, 3118, 3130, 3117, 3133
  average: 3119  # 平均: 3.119秒

自分の環境では、Scala版はJava版よりも1.5倍程度遅かった。
※ まだScalaでの効率の良い書き方とかが全然分かっていないので、その辺りを踏まえてちゃんと最適化を行えばもっと差は縮まるかもしれない


書きやすさはScalaの方が全然上だけど、(今回のケースでは)まだJavaに比べると若干遅い感じはする。

*1:形態素解析部のみ、バイナリ辞書データ構築部は未実装

簡易スタック型VM(JITコンパイラもどき)でのフィボナッチ数計算速度

前々々回でスタック型言語をバイトコードコンパイルする部分を、前々回でCommonLispアセンブラによるマシン語生成を、前回でそのアセンブラ上にスタック型言語のラップするところを扱った。
今回はそれらをまとめて、最初に作成したバイトコードインタプリタ(?)VMを、実行時にネイティブコードを生成するJIT(のようなもの)に置き換えて、実行速度を比較してみる。

バイトコード生成部

ここは前々回と全く同様なので省略。
以下にフィボナッチ数計算用のプログラムを再掲しておく。

(pvmc:compile-to-file
 "fib.bc"
 '(
   35    ; fib(35)
   (:addr fib-beg) :call ; fib(25)
   (:addr finish)  :jump
   
   fib-beg
   :dup 2  :less (:addr fib-end) :jump-if  ; if(n < 2) 
   :dup 2  :sub  (:addr fib-beg) :call     ; fib(n - 2)
   :swap 1 :sub  (:addr fib-beg) :call     ; fib(n - 1)
   :add
   fib-end
   :return
   
   finish))
#|
$ od -h fib.bc
0000000 2301 0000 0100 0011 0000 0113 003a 0000
0000020 0911 0201 0000 0800 3901 0000 1200 0109
0000040 0002 0000 0103 0011 0000 0b13 0101 0000
0000060 0300 1101 0000 1300 1402
|#

バイトコード実行(VM)部

前々回はこの部分をC++で作成したが、今回はCommonLispで実装する。
まずはバイトコード実行用の関数の定義。

;;; ファイル名: pvm-execute.lisp

;; アセンブラを読み込んでおく
(asdf:load-system :cl-asm)

;; パッケージ定義
(defpackage pvm-execute
  (:use :common-lisp :sb-alien)
  (:nicknames :pvme)
  (:export execute        ; バイトコードのファイルパスを受け取り実行結果を返す関数
           make-command)) ; バイトコード実行用のコマンドを生成する     
(in-package :pvm-execute)

;; 前回定義した@pushや@pop、その他の関数定義がここにくる
;; ... 省略 ...
;;

;; バイトコードのファイルパスを受け取り評価・実行する
(defun execute (filepath)
  (with-open-file (in filepath :element-type '(unsigned-byte 8))
    (cl-asm:execute (convert-to-executable (read-bytecodes in))
                    (function int))))

;; 入力ストリームからバイトコードを読み込み、cl-asmのニーモニック形式に変換する
(defun read-bytecodes (in)
  (loop FOR pos = (file-position in)
        FOR op = (read-op in)
        WHILE op
    COLLECT
    ;; 各バイトコードを(開始位置 ニーモニック)形式に変換する
    ;; 開始位置は、後にアドレス解決を行う際に使用される
    (list 
     pos
     (ecase op
       (1 `(@int ,(read-int in)))
       (2 '(@add))  ; @で始まる関数群は、前回定義したもの
       (3 '(@sub))
       (4 :mul (error "unsupported")) ; 未対応
       (5 :div (error "unsupported"))
       (6 :mod (error "unsupported"))
       (7 '(@eql))
       (8 '(@less))
       (9 '(@dup))
       (10 '(@drop))
       (11 '(@swap))
       (12 '(@over))
       (13 '(@rot))
       (14 :rpush (error "unsupported"))
       (15 :rpop (error "unsupported"))
       (16 :rcopy (error "unsupported"))
       (17 '(unresolve @jump))    ; アドレス解決が必要 (resolve-addrs関数内で行う)
       (18 '(unresolve @jump-if)) ; 同上
       (19 '(unresolve @call))    ; 同上
       (20 '(@return))))))

;; 読み込んだニーモニック(の中間形式)を、実行可能な(= cl-asm:executeに渡せる)に変換する
(defun convert-to-executable (mnemonics)
  (eval 
   `(body ,@(mapcar #'second (resolve-addrs mnemonics)) ; 本体
          (@pop %eax))))                                ; 結果を取り出して返す

;; 各種補助関数
(defun read-op (in)    ; バイト読み込み 
  (read-byte in nil nil))

(defun read-uint (in)  ; unsigned int読み込み
  (+ (ash (read-byte in) 00)
     (ash (read-byte in) 08)
     (ash (read-byte in) 16)
     (ash (read-byte in) 24)))

(defun read-int (in)   ; signed int読み込み
  (let ((n (read-uint in)))
    (if (< n #x80000000)
        n
      (- n #x100000000))))

(defun symb (&rest args)  ; シンボル生成: (symb "ABC" 1) => 'abc1
  (intern (format nil "~{~a~}" args)))

;; jump命令やcall命令が参照するアドレスをcl-asmが扱える形式に変換する
;; 
;; バイトコードでは遷移系の命令の直前に遷移先(絶対アドレス)が指定されているので、
;; mnemonics内の'((@int 10) (unresolve @call))のようになっている部分を '((@call &10)) のように置き換える。
;; ※ 変換時に生成したアドレス用のラベル(上の場合は'&10)は、最後にまとめてmnemonics内の適切な位置に挿入する。
(defun resolve-addrs (mnemonics)
  (labels ((recur (list acc addrs)
             (if (null list)
                 (values (nreverse acc) 
                         (remove-duplicates addrs))
               (let ((tag (first (second (car list)))))
                 (case tag
                   (unresolve 
                    (destructuring-bind ((_ (__ addr)) . acc2) acc
                      (declare (ignore _ __))
                      (let ((pos (first (car list)))
                            (op (second (second (car list)))))
                        (recur (cdr list) 
                               (cons `(,pos (,op ,(symb "&" addr))) acc2)
                               (cons addr addrs)))))
                   (otherwise
                    (recur (cdr list) (cons (car list) acc) addrs)))))))
    (multiple-value-bind (mnemonics refered-addrs)
                         (recur mnemonics '() '())
      (sort 
       (append mnemonics
               (loop FOR addr IN refered-addrs
                     COLLECT `(,(- addr 0.5) ,(symb "&" addr))))
       #'<
       :key #'first))))

resolve-addrs関数が若干複雑*1なことを除いては、バイトコードからのほぼ一対一の単純な変換となっている。

後は、前々回に合わせて実行部は通常のUnixコマンドとして使えるようにしておく。

;;; main関数作成用の補助関数
(eval-when (:compile-toplevel :load-toplevel :execute)
  ;; "/dir/file.ext" -> "file.ext"
  (defun basename (pathstring)
    (let ((path (parse-namestring pathstring)))
      (format nil "~A~@[.~A~]" (pathname-name path) (pathname-type path))))

  ;; '(a b c &optional c &key (d e)) -> '(a b c d)
  (defun collect-varsym (args)
    (mapcar (lambda (a)
	      (if (consp a) (car a) a))
	    (remove-if (lambda (a)
			 (and (symbolp a) (string= "&" a :end2 1)))
		       args))))

;;; main関数定義関数
(defmacro defmain (fn-name args &body body)
  (let ((usage nil))
    ;; If first expression of body is string type, it treated as command documentation
    (when (stringp (car body))
      (setf usage (car body)
	    body  (cdr body)))
    
    `(defun ,fn-name ()
       ;; Need to override *invoke-debugger-hook*
       (let ((sb-ext:*invoke-debugger-hook*
	      (lambda (condition hook)
		(declare (ignore hook))
		(format *error-output* "Error: ~A~%" condition)
		(sb-ext:quit :unix-status 1))))
         
	 ;; When failed arguments destructuring, show documentation and exit
	 ,(when usage
	    `(handler-case 
	      (destructuring-bind ,args (cdr sb-ext:*posix-argv*) 
	        (declare (ignore ,@(collect-varsym args))))
	      (error ()
	        (format *error-output* "~&~?~%~%" 
			,usage
			(list (basename (car sb-ext:*posix-argv*))))
		(sb-ext:quit :unix-status 1))))

         (destructuring-bind ,args (cdr sb-ext:*posix-argv*)
           ,@body
	   (sb-ext:quit :unix-status 0))))))

;;; main関数
;;; 引数で指定されたファイルパスに対してexecute関数を呼び出すだけ
(defmain main (bytecode-filepath)
  "Usage: ~a BYTECODE_FILEPTAH"
  (print (execute bytecode-filepath))
  (terpri))

;;; コマンド生成関数
(defun make-command (command-name)
  (sb-ext:save-lisp-and-die command-name :executable t :toplevel #'main))

コマンド生成&実行。

$ sbcl
> (load "pvm-execute.lisp")
> (pvme:make-command "pvm-jit")
[undoing binding stack and other enclosing state... done]
[saving current Lisp image into pvm-jit:
writing 6336 bytes from the read-only space at 0x20000000
writing 4000 bytes from the static space at 0x20100000
writing 46170112 bytes from the dynamic space at 0x1000000000
done]  ; pvm-jitコマンドが生成される

$ ./pvm-jit
Usage: pvm-jit BYTECODE_FILEPTAH

# フィボナッチ数計算
$ time ./pvm-jit fib.bc
9227465    # fib(35) = 9227465

real	0m0.169s
user	0m0.156s
sys	0m0.008s

# 前々回のコマンドの場合
$ time ./pvm fib.bc
[data stack]
 0# 9227465

[return stack]

real	0m3.636s
user	0m3.632s
sys	0m0.000s

比較

比較表に今回の結果を追記(pvm-jit)

言語 所要時間(最適化オプションなし) 所要時間(最適化オプションあり)
gcc-4.6.1 0.112s 0.056s
sbcl-1.0.54 0.320s 0.110s
pvm 3.600s
pvm-jit 0.156s
ruby1.9.1 2.816s
ruby1.8.7 14.497s
cl-asm 0.059s

不完全なアセンブラ及び最適化一切無しの単純な変換(バイトコード=>マシン語)という条件化でも、やはりインタプリタよりは桁違い(20倍程度)に速くなっている*2
データスタック操作周りで明らかに冗長な部分の最適化を簡単にでも行ったら、最適化オプション無しのgccになら結構すぐに追いつけるかもしれない。

*1:アドレス参照周りの仕様をなおざりにしすぎた・・・

*2:加えてVM部のソースコードも、インタプリタのものに比べて過度に複雑になっている、ということもない

アセンブリ言語でフィボナッチ数

前回は、C++で単純なVMを書いて、その上でのフィボナッチ数の計算時間を測定した。
そのVM部分をネイティブコードに置き換えたら、どの程度処理速度が改善するのかを測ってみたかったので、その前にまずネイティブコード(x86)の勉強も兼ねて、common lispアセンブラを書くことにした。
現状はまだまだ未完成で、以下のような制限があるが、一応フィボナッチ数が計算できるくらいまでには出来たので、その計算時間を参考までに残しておく。
制限:

  • 使用可能な命令は mov/ret/push/pop/add/sub/inc/dec/cmp/jmp/jcc/call のみ
  • 64bitのみ対応
  • エラーチェックとか不十分
  • SBCLのみで動作

github: cl-asm-0.0.1

コード

フィボナッチ数計算用のコード。

(use-package :sb-alien)

;; Fibonacci用のアセンブリコード
(defparameter *fib*
 '((:push %rbp) (:mov %rbp %rsp) (:push %rdi) (:push %rsi) (:push %rbx)  ; 関数呼び出し時の定形処理

   (:mov %eax %edi)  ; 引数取得
   (:call &fib-beg) 
   
   (:pop %rbx) (:pop %rsi) (:pop %rdi) (:pop %rbp)  ; 関数から返る時の定形処理
   :ret
  
  &fib-beg
  (:cmp %eax 2)      ; arg < 2
  (:jl &fib-end)
  
  (:push %rax)
  (:sub %eax 2)
  (:call &fib-beg)   ; x = (fib (- arg 2))
  (:pop %rbx)

  (:push %rax)
  (:mov %eax %ebx)
  (:dec %eax)
  (:call &fib-beg)   ; y = (fib (- arg 1))
  (:pop %rbx)
  
  (:add %eax %ebx)   ; (+ x y)
  &fib-end
  :ret))
--> *FIB*

;; 生成されるマシン語
(cl-asm:assemble *fib*)
--> (85 72 137 229 87 86 83 137 248 232 5 0 0 0 91 94 95 93 195 131 248 2 124 23 80
     131 232 2 232 242 255 255 255 91 80 137 216 255 200 232 231 255 255 255 91 1
     216 195)

;; 実行
(time
 (cl-asm:execute
   *fib*
  (function int int)  ; 関数の型
  35)                 ; 引数:  (fib 35)
Evaluation took:
  0.059 seconds of real time
  0.060003 seconds of total run time (0.060003 user, 0.000000 system)
  101.69% CPU
  117,246,804 processor cycles
  32,624 bytes consed
--> 9227465

比較

前回の他言語での測定結果に、上での計測結果を追加したもの。

言語 所要時間(最適化オプションなし) 所要時間(最適化オプションあり)
gcc-4.6.1 0.112s 0.056s
sbcl-1.0.54 0.320s 0.110s
pvm 3.600s
ruby1.9.1 2.816s
ruby1.8.7 14.497s
cl-asm 0.059s

やっぱりマシン語直出力は速い。最適化されたGCCよりは遅いけど。

簡易スタック型VM(バイトコードインタプリタ)でのフィボナッチ数計算速度

今年はlisp系のプログラミング言語(及びその処理系)を作ってみようと考えていて、かつ(少なくとも)当面の間はスタック型VMを基盤として実装していくことになると思われるので、まずは単純なスタックマシンのバイトコードインタプリタで、どの程度の処理速度がでるのかを計測してみた。

命令一覧と実行サンプル

現状のVMが備える命令一覧*1。必要最小限。
下記、命令セットに関してはForthを少し参考にしている。スタックマシンの動作の詳細に関しては、特に特殊な点もないので説明は割愛。

命令 コード値 in-stack out-stack 意味
int 1 n バイトコード中の後続の四バイト(little-endian)を取り出し、int値を生成
add 2 n1 n2 n3 n1 + n2
sub 3 n1 n2 n3 n1 - n2
mul 4 n1 n2 n3 n1 * n2
div 5 n1 n2 n3 n1 / n2
mod 6 n1 n2 n3 n1 % n2
eql 7 n1 n2 b(1 or 0) n1 == n2
less 8 n1 n2 b n1 < n2
dup 9 x x x スタックの先頭要素を複製
drop 10 x スタックの先頭要素を破棄
swap 11 x1 x2 x2 x1 スタックの先頭二つの要素を入れ替え
over 12 x1 x2 x1 x2 x1 スタックの二番目の要素を先頭に複製
rot 13 x1 x2 x3 x2 x3 x1 スタックの先頭三つの要素をローテーション
rpush 14 x スタック(データスタック)の先頭要素をリターンスタックの先頭に移す
rpop 15 x リターンスタックの先頭要素をスタックに移す
rcopy 16 x リターンスタックの先頭要素をスタックに複製
jump 17 n 無条件分岐。nは分岐先のアドレス
jump_if 18 b n 条件分岐。bが新(非ゼロ)なら分岐する
call 19 n 関数呼び出し。リターンスタックにプログラムカウンタを保存後、無条件分岐
return 20 関数からの復帰。リターンスタックからプログラムカウンタを取り出し、そこへ無条件分岐

末尾にソースコード全体を載せるが、バイトコードインタプリタの実行部は、バイトコードから上記命令に対応するコード値を取得し、命令を実行する、ということをひたすら繰り返すという単純なもの。

  // C++
  typedef unsigned char octet;

  /**
   * バイトコード実行用のクラス
   */
  class executor {
  public:
    void execute(const char* filepath) {
      bytecode_stream in(filepath);
      
      // バイトコードストリームの終端に達するまでループ
      while(in.eos() == false) {
        octet opcode = in.read_octet();  // 命令コード読み出し
        op::call(opcode, in, env);       // コードに対応する処理を実行 (envにはデータスタックとリターンスタックが保持されている)
      }
    }
  };

  class op {
  public:
    // コードに対応する命令を実行
    static void call(octet opcode, bytecode_stream& in, environment& env) {
      switch(opcode) {
      case  1: op_int(in, env); break; // int値構築
      case  2: op_add(in, env); break; // +
      case  3: op_sub(in, env); break; // -
      case  4: op_mul(in, env); break; // *
      case  5: op_div(in, env); break; // /
      case  6: op_mod(in, env); break; // %
      case  7: op_eql(in, env); break; // ==
      ... 省略 ...
        
      default:
        assert(false);
      }
    }
  }

VM部はC++で記述しているが、VMが解釈可能なバイトコード列を生成するためのアセンブラ(コンパイラ)はcommon lispで作成。

;; common lisp
;; 実行例
(load "pvm-compile")

;; 加算を行うバイトコード列を'add.bc'ファイルに出力する
;;  - キーワードは命令を表す
(pvmc:compile-to-file
 "add.bc"
 '(10 20 :add))  ; 10 + 20

;; 条件分岐を行うバイトコード列を'jump.bc'ファイルに出力する
;;
;; シンボルはアドレス参照用のラベルを表す
;; (:addr シンボル)形式で参照可能
;; ※ アドレスはコンパイル時に解決される
(pvmc:compile-to-file
 "jump.bc"
 '(10 10 :eql            ; n1 == n2 ?
   (:addr then) :jump-if ; 等しいなら then に移動
   else
   1 2     ; else: スタックに 1と2 を積む
   (:addr end) :jump
   then 
   3 4    ; then: スタックに 3と4 を積む
   end))

;; 上の例では以下のようなバイト列が生成される
(pvmc::compile-to-bytecodes
 '(10 10 :eql (:addr then) :jump-if else 1 2 (:addr end) :jump then 3 4 end))
 => #(1 10 0 0 0 1 10 0 0 0 7 1 33 0 0 0 18 1 1 0 0 0 1 2 0 0 0 1 43 0 0 0 17 1 3 0
      0 0 1 4 0 0 0)

生成したバイトコードはpvmコマンドで実行可能。

# pvmコマンド作成
$ g++ -O2 -o pvm pvm.cc

# add.bc
$ ./pvm add.bc
[data stack]    # 実行後のデータスタックとリターンスタックの中身が出力される
 0# 30   # 10 + 20

[return stack]

# jump.bc
$ ./pvm jump.bc
[data stack]
 0# 4     # then部が実行された
 1# 3

[return stack]

実行速度

上のVM上でのフィボナッチ数の計算に要した時間。
以下は35のフィボナッチ数計算用のコード。

(pvmc:compile-to-file
 "fib.bc"
 '(
   35    ; fib(35)
   (:addr fib-beg) :call ; fib(25)
   (:addr finish)  :jump
   
   fib-beg
   :dup 2  :less (:addr fib-end) :jump-if  ; if(n < 2) 
   :dup 2  :sub  (:addr fib-beg) :call     ; fib(n - 2)
   :swap 1 :sub  (:addr fib-beg) :call     ; fib(n - 1)
   :add
   fib-end
   :return
   
   finish))

#| 実行結果:
$ time ./pvm fib.bc 
[data stack]
 0# 9227465

[return stack]


real	0m3.605s
user	0m3.600s
sys	0m0.000s
|#

他言語との比較。

言語 所要時間(最適化オプションなし) 所要時間(最適化オプションあり)
gcc-4.6.1 0.112s 0.056s
sbcl-1.0.54 0.320s 0.110s
pvm 3.600s
ruby1.9.1 2.816s
ruby1.8.7 14.497s

現状は本当に単純なインタプリタなので仕方がないとはいえ、Ruby(1.9)よりも遅い・・・。

ちなみに各言語用のソースコードは以下の通り。

// C++
// ファイル名: fib.cc
// コンパイル: g++ -O2 -o fib fib.cc
// 実行: time fib 35
#include <iostream>
#include <cstdlib>

int fib(int n) {
  if(n < 2) {
    return n;
  }
  return fib(n-2) + fib(n-1);
}

int main(int argc, char** argv) {
  std::cout << fib(atoi(argv[1])) << std::endl;
  return 0;
}
;; sbcl
(defun fib (n)
  (declare (optimize (speed 3) (safety 0) (debug 0))
           (fixnum n))
  (if (< n 2)
      n
    (the fixnum (+ (fib (- n 2)) (fib (- n 1))))))

;; 実行
(time (fib 35))
# ruby
# ファイル名: fib.rb
# 実行: time fib.rb 35
def fib (n)
  return n if n < 2
  fib(n-2) + fib(n-1)
end

p fib(ARGV[0].to_i)

ソースコード

VM及びコンパイラ用のソースコード
それぞれ180行、80行程度。

// ファイル名: pvm.hh
/**
 * バイトコードインタプリタ
 */
#ifndef PVM_HH
#define PVM_HH

#include <iostream>
#include <fstream>
#include <cassert>
#include <vector>
#include <algorithm>

namespace pvm {
  typedef unsigned char octet;
  typedef std::vector<int> stack_t;

  
  /**
   * バイトコード読み込みストリーム
   */
  class bytecode_stream {
  public:
    bytecode_stream(const char* filepath) : bytecodes(NULL), position(0) {
      std::ifstream in(filepath);
      assert(in);

      length = in.rdbuf()->in_avail();
      bytecodes = new octet[length];
      in.read((char*)bytecodes, length);
    }
    
    ~bytecode_stream() { delete [] bytecodes; }
    
    bool eos() const { return position >= length; }
    
    octet read_octet () { return bytecodes[position++]; }

    // sizeof(int) == 4 と仮定
    int read_int() {
      int n = *(int*)(bytecodes + position);
      position += 4;
      return n;
    }

    // program counter
    unsigned pc() const { return position; }
    unsigned& pc() { return position; }
    
  private:
    octet* bytecodes;
    unsigned length;
    unsigned position;
  };


  /**
   * データスタックとリターンスタック
   */
  class environment {
  public:
    stack_t& dstack() { return data_stack; }
    stack_t& rstack() { return return_stack; }

    const stack_t& dstack() const { return data_stack; }
    const stack_t& rstack() const { return return_stack; }

  private:
    stack_t data_stack;
    stack_t return_stack;
  };


  /**
   * 各種操作(命令)
   */
  class op {
  public:
    static void call(octet opcode, bytecode_stream& in, environment& env) {
      switch(opcode) {
      case  1: op_int(in, env); break; // int値構築
      case  2: op_add(in, env); break; // +
      case  3: op_sub(in, env); break; // -
      case  4: op_mul(in, env); break; // *
      case  5: op_div(in, env); break; // /
      case  6: op_mod(in, env); break; // %
      case  7: op_eql(in, env); break; // ==
      case  8: op_less(in, env); break;// <

      case  9: op_dup(in, env); break;  // データスタックの先頭要素を複製
      case 10: op_drop(in, env); break; // データスタックの先頭要素を破棄
      case 11: op_swap(in, env); break; // データスタックの最初の二つの要素を入れ替え
      case 12: op_over(in, env); break; // データスタックの二番目の要素を先頭にコピーする
      case 13: op_rot(in, env); break;  // データスタックの先頭三つの要素をローテーションする
        
      case 14: op_rpush(in, env); break; // データスタックの先頭要素を取り出しリターンスタックに追加する
      case 15: op_rpop(in, env); break;  // リターンスタックの先頭要素を取り出しデータスタックに追加する
      case 16: op_rcopy(in, env); break; // リターンスタックの先頭要素をデータすタックに追加する

      case 17: op_jump(in, env); break;    // 無条件分岐
      case 18: op_jump_if(in, env); break; // 条件分岐
      case 19: op_call(in, env); break;    // 関数呼び出し
      case 20: op_return(in, env); break;  // 関数から復帰
        
      default:
        assert(false);
      }
    }

  private:
    typedef bytecode_stream bcs;
    typedef environment env;
    
#define DPUSH(x) e.dstack().push_back(x)
#define DPOP pop_back(e.dstack())
#define DNTH(nth) e.dstack()[e.dstack().size()-1-nth]

#define RPUSH(x) e.rstack().push_back(x)
#define RPOP pop_back(e.rstack())
#define RNTH(nth) e.rstack()[e.rstack().size()-1-nth]

    static void op_int(bcs& in, env& e) { DPUSH(in.read_int()); }
    static void op_add(bcs& in, env& e) { DPUSH(DPOP + DPOP); }
    static void op_sub(bcs& in, env& e) { int n = DPOP; DPUSH(DPOP - n); }
    static void op_mul(bcs& in, env& e) { DPUSH(DPOP * DPOP); }
    static void op_div(bcs& in, env& e) { int n = DPOP; DPUSH(DPOP / n); }
    static void op_mod(bcs& in, env& e) { int n = DPOP; DPUSH(DPOP % n); }
    static void op_eql(bcs& in, env& e) { DPUSH(DPOP == DPOP); }
    static void op_less(bcs& in, env& e) { DPUSH(DPOP > DPOP); }

    static void op_dup(bcs& in, env& e) { DPUSH(DNTH(0)); }
    static void op_drop(bcs& in, env& e) { DPOP; }
    static void op_swap(bcs& in, env& e) { std::swap(DNTH(0), DNTH(1)); }
    static void op_over(bcs& in, env& e) { DPUSH(DNTH(1)); }
    static void op_rot(bcs& in, env& e) { std::swap(DNTH(2), DNTH(0)); std::swap(DNTH(1), DNTH(2)); }

    static void op_rpush(bcs& in, env& e) { RPUSH(DPOP); }
    static void op_rpop(bcs& in, env& e) { DPUSH(RPOP); }
    static void op_rcopy(bcs& in, env& e) { DPUSH(RNTH(0)); }

    static void op_jump(bcs& in, env& e) { in.pc() = DPOP;}
    static void op_jump_if(bcs& in, env& e) { int p = DPOP; if(DPOP){ in.pc() = p;} }
    static void op_call(bcs& in, env& e) { RPUSH(in.pc()); in.pc() = DPOP; }
    static void op_return(bcs& in, env& e) { in.pc() = RPOP; }

#undef DPUSH
#undef DPOP
#undef DNTH

#undef RPUSH
#undef RPOP
#undef RNTH

  private:
    static int pop_back(stack_t& stack) {
      int x = stack.back();
      stack.pop_back();
      return x;
    }
  };


  /**
   * バイトコード実行
   */
  class executor {
  public:
    void execute(const char* filepath) {
      bytecode_stream in(filepath);
      
      while(in.eos() == false) {
        octet opcode = in.read_octet();
        op::call(opcode, in, env);
      }
    }
    
    const environment& get_env() const { return env; }

  private:
    environment env;
  };
}

#endif
// ファイル名: pvm.cc
// バイトコード実行用コマンド
#include "pvm.hh"
#include <iostream>

void show(const char* name, const pvm::stack_t& stack) {
  std::cout << "[" << name << "]" << std::endl;
  for(int i = stack.size()-1; i >= 0; i--) {
    std::cout << " " << (stack.size()-1-i) << "# " << stack[i] << std::endl;
  }
  std::cout << std::endl;  
}

int main(int argc, char** argv) {
  if(argc != 2) {
    std::cerr << "Usage: pvm BYTECODE_FILEPATH" << std::endl;
    return 1;
  }
  
  pvm::executor vm;
  vm.execute(argv[1]);

  const pvm::environment& rlt = vm.get_env();
  show("data stack", rlt.dstack());
  show("return stack", rlt.rstack());

  return 0;
}
;;; ファイル名: pvm-compile.lisp
;;; S式をVM用のバイトコードにコンパイル(アセンブル)する
(defpackage pvm-compile
  (:use :common-lisp)
  (:nicknames :pvmc)
  (:export compile-to-file))
(in-package :pvm-compile)

;; 利用可能な操作(命令)のリスト
(defparameter *op-list*
  '((1 :int)
    (2 :add)
    (3 :sub)
    (4 :mul)
    (5 :div)
    (6 :mod)
    (7 :eql)
    (8 :less)

    (9 :dup)
    (10 :drop)
    (11 :swap)
    (12 :over)
    (13 :rot)

    (14 :rpush)
    (15 :rpop)
    (16 :rcopy)
    
    (17 :jump)
    (18 :jump-if)
    (19 :call)
    (20 :return)))

;; 数値をリトルエンディアンのバイト列に変換する
;; n -> '(b1 b2 b3 b4)
(defun int-to-bytes (n)
  (loop FOR i FROM 0 BELOW 4
        COLLECT (ldb (byte 8 (* i 8)) n)))

;; 操作名に対するコード値を取得する
(defun opcode (op)
  (assert #1=(find op *op-list* :key #'second))
  (first #1#))

;; コンパイル
(defun compile-to-bytecodes (codes)
  (loop WITH unresolves = '()  ; 未解決のアドレス参照
        WITH labels = '()      ; ラベルとアドレスのマッピング
        FOR code IN codes
        FOR pos = (length tmps)
    APPEND
    (etypecase code
      (integer `(,(opcode :int) ,@(int-to-bytes code))) ; 整数値構築
      (keyword (list (opcode code)))                    ; 一般的な操作
      (symbol (push `(,code ,pos) labels)               ; アドレス(PC)参照用のラベル
              '())
      (cons (ecase (first code)                         ; アドレス参照
              (:addr (push `(,(second code) ,(1+ pos)) unresolves)
                     `(,(opcode :int) 0 0 0 0))))) ; この時点では実際のアドレスが不明なので 0 を設定しておく
    INTO tmps
    FINALLY
    (let ((bytecodes (coerce tmps 'vector)))
      ;; アドレス解決
      (loop FOR (label offset) IN unresolves
            FOR label-addr = (second (assoc label labels))
        DO
        (setf (subseq bytecodes offset (+ offset 4)) (int-to-bytes label-addr)))

      (return bytecodes))))

;; コンパイルして結果をファイルに出力する
(defun compile-to-file (filepath assembly-codes)
  (let ((bytecodes (compile-to-bytecodes assembly-codes)))
    (with-open-file (out filepath :direction :output
                                  :if-exists :supersede
                                  :element-type '(unsigned-byte 8))
      (write-sequence bytecodes out)))
  t)

*1:大別すると整数処理系、データスタック操作系、リターンスタック操作系、分岐系の四つ