読者です 読者をやめる 読者になる 読者になる

ループ処理を関数型っぽく書いてみる(2)

common lisp speed optimize

前回の続き。
githubにあるloopの簡易版を載せておく。

基本的な考え方

基本的なJava等のIteratorと似た*1インタフェースを通してループ処理を実現している。
異なるのは全ての関数をinline展開可能にすることで、同等のループを非関数型的に書いた場合と同じくらいに、コンパイラが最適化を行ってくれることを期待していることくらい。
後は、SBCLの最適化の制限上、構造体等は使用せず、極力lambdaで全てを表現するようにしている。

実装

まず、loopパッケージ用のシーケンス生成関数。

;; 数値の範囲を表現するシーケンス
(declaim (inline from))
(defun from (start &key to (by 1))  ; toがnilなら無限シーケンス
  ;; 全体をlambdaで囲む。このlambdaの呼び出しがシーケンスの初期化処理に相当する。
  (lambda () 
    (let ((cur start))
      ;; 以下の三つの関数を呼び出し元に返す
      (values (lambda () (incf cur by))          ; 1] 値更新関数
              (lambda () (and to(> cur to)))     ; 2] 終端判定関数
              (lambda (fn) (funcall fn cur)))))) ; 3] ループの本体実行関数

;; リスト用
(declaim (inline for-list))
(defun for-list (list)
  (lambda ()
    (let ((head list)) ; 初期値
      (values (lambda () (setf head (cdr head)))         ; 1] 値更新関数
              (lambda () (endp head))                    ; 2] 終端判定関数
              (lambda (fn) (funcall fn (car head)))))))  ; 3] ループの本体実行関数

;; 実行
> (from 1 :to 10)
#<CLOSURE (LAMBDA () :IN FROM) {1007855D3B}>

> (funcall (from 1 :to 10))
#<CLOSURE (LAMBDA () :IN FROM) {10078C432B}>   ; 1] 値更新関数
#<CLOSURE (LAMBDA () :IN FROM) {10078C434B}>   ; 2] 終端判定関数
#<CLOSURE (LAMBDA (FN) :IN FROM) {10078C436B}> ; 3] ループの本体実行関数

> (setf (values next end? call-body) (funcall (from 1 :to 10)))
> (funcall call-body (lambda (x) (list :val x)))
=> (:val 1)

> (funcall next)
> (funcall call-body (lambda (x) (list :val x)))
=> (:val 2)

上の関数で生成されたシーケンスを走査する関数。

;; 一番基本となる走査関数
(declaim (inline each))
(defun each (fn seq)
  (multiple-value-bind (next-fn end-fn call-fn) (funcall seq)  ; シーケンス初期化
    (loop UNTIL (funcall end-fn)   ; 終端判定
          DO (funcall call-fn fn)  ; 本体実行
             (funcall next-fn))))  ; 値更新

;; 畳み込み関数  ※ reduceはclパッケージとそれと名前が衝突するので、ここではfoldにしている
(declaim (inline fold))
(defun fold (fn init seq)
  (let ((acc init))
    (each (lambda (x)
            (setf acc (funcall fn acc x)))
          seq)
    acc))

;; シーケンスを集めたリストを返す
(declaim (inline collect))
(defun collect (seq)
  (nreverse (fold (lambda (acc x) (cons x acc))
                    '()
                    seq)))

;; 実行
> (collect (from 1 :to 20 :by 3))
=> (1 4 7 10 13 16 19)

; 合計値計算
> (fold (lambda (acc x) (+ acc x))
        0
        (from 1 :to 20 :by 3))
=> 70

mapとかfilterとかシーケンスを加工/制御する関数。

;; map関数
(declaim (inline map-seq))
(defun map-seq (map-fn seq)
  ;; ソースとなるシーケンスの情報を取得し、それをラップして返す
  (multiple-value-bind (next-fn end-fn call-fn) (funcall seq)
    (lambda ()
      (values next-fn  ; 値更新関数と終端判定関数はそのまま
              end-fn
              (lambda (body-fn)
                ;; 本体呼び出し前に、マップ処理用関数を差し込む
                (funcall call-fn (lambda (val) (funcall body-fn (funcall map-fn val)))))))))

;; filter関数: (funcall pred-fn val)がnilとなる要素をスキップする
(declaim (inline filter))
(defun filter (pred-fn seq)
  ;; ソースとなるシーケンスの情報を取得し、それをラップして返す
  (multiple-value-bind (next-fn end-fn call-fn) (funcall seq)
    (lambda ()
      (values next-fn ; 値更新関数と終端判定関数はそのまま
              end-fn
              (lambda (body-fn)
                ;; 本体呼び出し前に、フィルター処理用関数を差し込む
                (funcall call-fn (lambda (val)
                                   (unless (funcall pred-fn val)
                                     (funcall body-fn val)))))))))

;; 実行
; 二乗する
> (collect (map-seq (lambda (x) (* x x)) (from 1 :to 20)))
=> (1 4 9 16 25 36 49 64 81 100 121 144 169 196 225 256 289 324 361 400)

; 奇数の値だけフィルタして二乗する
> (collect (map-seq (lambda (x) (* x x)) (filter #'oddp (from 1 :to 20))))
=> (1 9 25 49 81 121 169 225 289 361)

これらの関数群を組み合わせてループ処理を表現すると、そこそこ良い感じのコードを生成してくれる。

;; 上で定義した関数群を用いたsum関数
;; - startからendの範囲の奇数値を-10した合計値を返す
(defun sum1 (start end)
 (declare (fixnum start end)
          (optimize (speed 3) (safety 0))
          (sb-ext:unmuffle-conditions sb-ext:compiler-note))
 (fold (lambda (acc n)
         (the fixnum (+ acc n)))
       0
       (map-seq (lambda (x) (- x 10)) 
                (filter #'oddp (from start :to end)))))

;; loopマクロを使用したsum関数
(defun sum2 (start end)
 (declare (fixnum start end)
          (optimize (speed 3) (safety 0))
          (sb-ext:unmuffle-conditions sb-ext:compiler-note))
 (loop WITH total fixnum = 0
       FOR i FROM start TO end
       WHEN (oddp i)
   DO (let ((n (- i 10)))
        (declare (fixnum n))
        (incf total n))
   FINALLY (return total)))

;; 一億要素に対するループ
> (time (sum1 1 100000000))
Evaluation took:
  0.134 seconds of real time  ; 0.134秒
  0.136009 seconds of total run time (0.136009 user, 0.000000 system)
  101.49% CPU
  267,335,373 processor cycles
  0 bytes consed
=> 2499999500000000

> (time (sum2 1 100000000))
Evaluation took:
  0.131 seconds of real time  ; 0.131秒
  0.132008 seconds of total run time (0.132008 user, 0.000000 system)
  100.76% CPU
  261,630,697 processor cycles
  0 bytes consed
=> 2499999500000000

;; disassemble結果
> (disassemble #'sum1)
; disassembly for SUM1
; 07AA1C28:       31D2             XOR EDX, EDX               ; no-arg-parsing entry point
;       2A:       EB1B             JMP L2
;       2C:       90               NOP
;       2D:       90               NOP
;       2E:       90               NOP
;       2F:       90               NOP
;       30: L0:   488BC1           MOV RAX, RCX
;       33:       488D1C4500000000 LEA RBX, [RAX*2]
;       3B:       4883E302         AND RBX, 2
;       3F:       4885DB           TEST RBX, RBX
;       42:       750E             JNE L3
;       44: L1:   48FFC1           INC RCX
;       47: L2:   4839F9           CMP RCX, RDI
;       4A:       7EE4             JLE L0
;       4C:       488BE5           MOV RSP, RBP
;       4F:       F8               CLC
;       50:       5D               POP RBP
;       51:       C3               RET
;       52: L3:   4883E80A         SUB RAX, 10
;       56:       48D1FA           SAR RDX, 1
;       59:       4801C2           ADD RDX, RAX
;       5C:       48D1E2           SHL RDX, 1
;       5F:       EBE3             JMP L1

> (disassemble #'sum2)
; disassembly for SUM2
; 07EF0DB8:       31D2             XOR EDX, EDX               ; no-arg-parsing entry point
;       BA:       EB2A             JMP L2
;       BC:       90               NOP
;       BD:       90               NOP
;       BE:       90               NOP
;       BF:       90               NOP
;       C0: L0:   488D044D00000000 LEA RAX, [RCX*2]
;       C8:       4883E002         AND RAX, 2
;       CC:       4885C0           TEST RAX, RAX
;       CF:       7412             JEQ L1
;       D1:       488D044D00000000 LEA RAX, [RCX*2]
;       D9:       488BD8           MOV RBX, RAX
;       DC:       4883EB14         SUB RBX, 20
;       E0:       4801DA           ADD RDX, RBX
;       E3: L1:   48FFC1           INC RCX
;       E6: L2:   4839F9           CMP RCX, RDI
;       E9:       7ED5             JLE L0
;       EB:       488BE5           MOV RSP, RBP
;       EE:       F8               CLC
;       EF:       5D               POP R

最後は複数シーケンスをまとめるzip関数。
これを使うと表現力はだいぶ上がるけど、性能は若干劣化する。

;; 二つのシーケンスをまとめる
(declaim (inline zip))
(defun zip (loop1 loop2 &aux (undef (gensym)))
  (multiple-value-bind (next-fn1 end-fn1 call-fn1) (funcall loop1)
    (multiple-value-bind (next-fn2 end-fn2 call-fn2) (funcall loop2)
      (let ((memo1 undef)
            (memo2 undef))
        (lambda ()
          (values (lambda () ; 値更新
                    (when (eq memo1 undef) (funcall next-fn1))
                    (when (eq memo2 undef) (funcall next-fn2)))

                  (lambda () ; 終端判定
                    (or (funcall end-fn1) (funcall end-fn2)))

                  (lambda (body-fn)  ; 本体呼び出し
                    ;; それぞれのシーケンスの次の値を取得する
                    ;; (次の値がfilterでスキップされた場合は memoX はundefのままになる)
                    (when (eq memo1 undef)
                      (funcall call-fn1 (lambda (val) (setf memo1 val))))

                    (when (eq memo2 undef)
                      (funcall call-fn2 (lambda (val) (setf memo2 val))))
    
                    ;; 両方のシーケンスの値が取得できたら、本体を呼び出す
                    (when (not (or (eq memo1 undef)
                                   (eq memo2 undef)))
                      (funcall fn (list memo1 memo2))  ; XXX: listで二つの値をまとめるのはconsingが発生するので効率が悪い (そのためloopパッケージでは、多引数を受け取るmapやfilterを用意している)
                      (setf memo1 undef
                            memo2 undef)))))))))

;; 実行
> (collect
    (zip (filter (lambda (n) (and (oddp n)  (zerop (mod n 3)))) (from 1))           ; 奇数かつ三の倍数
         (filter (lambda (n) (and (evenp n) (zerop (mod n 5)))) (from 1 :to 100)))) ; 偶数かつ五の倍数
=> ((3 10) (9 20) (15 30) (21 40) (27 50) (33 60) (39 70) (45 80) (51 90) (57 100))

zipはもう少し上手く実装したいところだけど、それでも関数型っぽく書いても実用上十分な性能がでるループ処理が実現できそうなことが分かったので、結構満足している。

*1:似てないかも