llvm : tutorial : optimize

チュートリアル第四章『4. Kaleidoscope: Adding JIT and Optimizer Support — LLVM 3.4 documentation』。
タイトルの通り、ここではJIT(を用いたRead-Eval-Print-Loop)とoptimizeがKaleidoscopeに加わっている。
前者は少しやっかいなので、分割して今回は後者だけを扱うことにする。

optimize実現方法

本章では、C++のIRBuilderとFunctionPassManagerというクラスを組み合わせて最適化を行う方法が説明されている。
ただ-前回と同様に-common lisp向けには、そういった便利なモジュールは提供されていないので、代わりにllvmのoptというコマンドを利用して最適化を行うことにする。※ optコマンドに関しては『opt - LLVM optimizer — LLVM 3.4 documentation』を参照

以下、optコマンドの使用例。
※ 下の例で行われている最適化がどのようなものかは、オリジナルのチュートリアルを参照のこと

######
# 補足
#  - llvm-as:  アセンブラ   (llvmアセンブリ言語 ==> ビットコード)
#  - llvm-dis: 逆アセンブラ (ビットコード ==> llvmアセンブリ言語)

#########
## 例1 ##
$ llvm-as | opt -O3 | llvm-dis
define double @test(double %x) {
entry:
        %addtmp = add double 2.000000e+00, 1.000000e+00
        %addtmp1 = add double %addtmp, %x
        ret double %addtmp1
}
^D  # ここまでが入力

# 最適化された出力
define double @test(double %x) nounwind readnone {
entry:
  %addtmp1 = fadd double %x, 3.000000e+00         ; <double> [#uses=1]
  ret double %addtmp1
}

#########
## 例2 ##
$ llvm-as | opt -O3 | llvm-dis
define double @test(double %x) {
entry:
        %addtmp = add double 3.000000e+00, %x
        %addtmp1 = add double %x, 3.000000e+00
        %multmp = mul double %addtmp, %addtmp1
        ret double %multmp
}
^D  # ここまでが入力

# 最適化された出力
; ModuleID = '<stdin>'

define double @test(double %x) nounwind readnone {
entry:
  %addtmp = fadd double %x, 3.000000e+00          ; <double> [#uses=2]
  %multmp = fmul double %addtmp, %addtmp          ; <double> [#uses=1]
  ret double %multmp
}

今回実質的に行うことは、上の例にある三つのコマンドに、lispが生成したllvmのコードを流すだけ。

実装

sbclのsb-grayパッケージを使って、普通の出力ストリームをラップした、最適化出力ストリームを作成する。
このラッパーストリームは、コードジェネレーターが生成/出力したllvmのコードをまず`llvm-as | opt -O3 | llvm-dis`コマンドに渡して最適化を行い、その結果をラップされたストリームに出力している。

;;;; llvmコードのoptimize用のラッパーストリーム定義
;;;; 以下のURLを参考に、必要最小限の部分だけを修正/実装している
;;;;  - http://www.sbcl.org/manual/Output-prefixing-character-stream.html#Output-prefixing-character-stream

(defclass llvm-opt-stream (sb-gray:fundamental-stream)
  ((stream :initarg :stream :reader stream-of)))

(defmethod stream-element-type ((stream llvm-opt-stream))
  (stream-element-type (stream-of stream)))

(defmethod close ((stream llvm-opt-stream) &key abort)
  (close (stream-of stream) :abort abort))

(defclass llvm-opt-character-output-stream
  (llvm-opt-stream sb-gray:fundamental-character-output-stream)
  ((col-index :initform 0   :accessor col-index-of)
   (buf       :initform '() :accessor buf-of)))

(defmethod stream-line-column ((stream llvm-opt-character-output-stream))
  (col-index-of stream))

;; 入力のllvmコード(input-program)に対して最適化を行い、
;; 結果を文字列として返す。
;; 最適化中にエラーが発生した場合は、入力をそのまま返す。
(defun do-optimize (input-program)
  (with-input-from-string (in input-program)
    (with-output-to-string (out)
      (let ((proc 
             (sb-ext:run-program "sh" '("-c" "llvm-as | opt -O3 | llvm-dis") 
                                 :search t :output out :input in)))
      (unless (zerop (sb-ext:process-exit-code proc))
        ;; コマンドの実行に失敗した場合
        (return-from do-optimize input-program))))))
        
(defmethod stream-write-char ((stream llvm-opt-character-output-stream) char)
  (with-accessors ((inner-stream stream-of) 
                   (cols col-index-of)
                   (buf  buf-of)) stream
    ;; ひとまとまりのllvmのコードが揃うまで、出力文字をバッファリングする
    ;; ※ 現状は、簡単のために'}'(関数定義の終了)が現れるまで、バッファリングを行っている
    (push char buf)
    (when (char= char #\})
      ;; コードを最適化して、その結果を出力する
      (princ (do-optimize (coerce (nreverse buf) 'string))
             inner-stream)
      (setf buf '()))
    (if (char= char #\Newline)
        (setf cols 0)
      (incf cols))))

合わせてmain-loop関数も少し修正。

(defun main-loop (&optional (*print-pprint-dispatch* *print-pprint-dispatch*)
                            (*standard-output* *standard-output*)) ; 出力ストリームを引数に追加
  (loop
   (princ "ready> " *error-output*) (force-output *error-output*)  ; プロンプトは*error-output*に
   (handler-case
    (format t "~&~S~%"
      (case (read-token)
        (#\;       )   
        (:def      (parse-definition))
        (:extern   (parse-extern))
        (otherwise (parse-toplevel-exp))))
    (end-of-file ()
      (return-from main-loop))
    (error (c)
      (read-token)
      (format *error-output* "~&Error: ~A~%" c)))))

実行。

> (defvar *opt-output-stream* 
    (make-instance 'llvm-opt-character-output-stream
                   :stream *standard-output*))

;;;;;;;;;;;;;;;;;;;
;; 最適化しない場合
> (main-loop *llvm-pp-table*)
ready> def test(x) (1+2+x)*(x+(1+2));

define double @test (double %x) {
entry:
%addtmp3 = fadd double 1.e+0, 2.e+0
%addtmp4 = fadd double %addtmp3, %x
%addtmp5 = fadd double 1.e+0, 2.e+0
%addtmp6 = fadd double %x, %addtmp5
%multmp7 = fmul double %addtmp4, %addtmp6
ret double %multmp7
}
ready> ^D

;;;;;;;;;;;;;;;;;
;; 最適化する場合
> (main-loop *llvm-pp-table* *opt-output-stream*)
ready> def test(x) (1+2+x)*(x+(1+2));
; ModuleID = '<stdin>'

define double @test(double %x) nounwind readnone {
entry:
  %addtmp9 = fadd double %x, 3.000000e+00         ; <double> [#uses=2]
  %multmp12 = fmul double %addtmp9, %addtmp9      ; <double> [#uses=1]
  ret double %multmp12
}

以上。