llvm : tutorial : code generation

llvmチュートリアルの続き。
今回はパースして得られた抽象構文木から、llvm(llvm IRの?)コードを生成する。

概要?

オリジナルのチュートリアルでは、コード生成にC++のビルダクラスが使われているが、common lisp用のそういったモジュールは用意されていないので、必要な分は適宜実装していくことにする。
また、せっかくcommon lispを使うので(?)で、pretty printを利用してコードを生成してみることにする。とはいっても大して活用はしていない。

ソースコード

まず最初は、前回の実装の補足から。

;; extern用の構造体(名前付きリスト)を追加
(defstruct (extern (:type list) :named) proto)

;; extern宣言をパースした時は、extern構造体でラップするようにする
;; ※ pretty printでディスパッチできるように
(defun parse-extern ()
  (read-token)
  (make-extern :proto (parse-prototype)))

;; それぞれのAST用の型を定義する
(deftype number-exp   () '(satisfies number-exp-p))
(deftype variable-exp () '(satisfies variable-exp-p))
(deftype binary-exp   () '(satisfies binary-exp-p))
(deftype call-exp     () '(satisfies call-exp-p))
(deftype prototype    () '(satisfies prototype-p))
(deftype define       () '(satisfies define-p))
(deftype extern       () '(satisfies extern-p))

次は、コード生成用のppディスパッチ表の準備と各種補助関数の定義。

;; llvmコード生成用のディスパッチ表を用意する
;; XXX: copy-pprint-dispatchの引数としてnilが適切かどうかは不明
(defvar *llvm-pp-table* (copy-pprint-dispatch nil))

;;;;;;;;;;;;;;;;
;;; 補助関数など
(defvar *return-value*)  ; コード生成(pretry print)関数の戻り値を保存しておくための変数
(defun gen-code (ast)    ; コード生成 ==> *return-value*を返す
  (princ ast)
  *return-value*)

;; コード生成関数作成補助マクロ
;; 主に次の三つのことを行う
;;   1] set-pprint-dispatchの簡略化
;;   2] コード生成関数の第一引数(出力ストリーム)を常に*standard-output*にバインドする
;;   3] コード生成関数の戻り値を、*return-value*に格納する
(defmacro def-llvm-ppd (type args &body body)
  (setf args (cons 'common-lisp:*standard-output* args))
  `(set-pprint-dispatch
    ',type
    (lambda ,args 
      (declare (ignorable common-lisp:*standard-output*))
      (setf *return-value* (progn ,@body)))
    0 
    *llvm-pp-table*))

;; llvmの識別子(シンボル)を作成する
;;   ローカル識別子:   %名前
;;   グローバル識別子: @名前
(defun llvm-sym (name &optional (type :local) temp)
  (check-type type (member :local :global))
  (format nil "~A~A" 
          (if (eq type :global) "@" "%")
          (if temp (gentemp name) name)))

;; prog1のアナフォリック版
(defmacro a.prog1 (exp1 &body rest-exps)
  `(let ((it ,exp1))
     (prog1 it
       ,@rest-exps)))

後は、コード生成関数の定義。

;;;;;;;;
;;; 数値
;; 全てdouble型に変換する
(def-llvm-ppd number-exp (exp)
  (format nil "~E" (number-exp-val exp)))

;;;;;;;;
;;; 変数
;; 変数の検索/束縛/初期化関数
;; ※ 変数は、関数の引数にのみ現れる
(let ((binding-variables '()))
  (defun find-var (varname) (find varname binding-variables :test #'equal))
  (defun bind-var (varname) (pushnew varname binding-variables :test #'equal))
  (defun clear-vars ()      (setf binding-variables '())))

(def-llvm-ppd variable-exp (exp)
  (if (find-var #1=(variable-exp-name exp))
      ;; 束縛済みの変数が見つかった場合は、それをllvmの識別子に変換して返す
      (llvm-sym #1# :local)
    (error "Unknown variable name: ~A" #1#)))

;;;;;;;;;;;;;;
;;; 二項演算子
;; コード生成: add, sub, mul, <, int=>double変換
(defun pp-add (x y name)
  (a.prog1 (llvm-sym name :local t)
    (format t "~&~A = fadd double ~A, ~A~%" it x y)))
(defun pp-sub (x y name)
  (a.prog1 (llvm-sym name :local t)
    (format t "~&~A = fsub double ~A, ~A~%" it x y)))
(defun pp-mul (x y name)
  (a.prog1 (llvm-sym name :local t)
    (format t "~&~A = fmul double ~A, ~A~%" it x y)))
(defun pp-< (x y name)
  (a.prog1 (llvm-sym name :local t)
    (format t "~&~A = fcmp ult double ~A, ~A~%" it x y)))

;; 修正(2010/02/19): uiの型も引数(uitype)で渡すように変更  ※ 最終的には、ui自体が型情報を持つようするのが望ましい
(defun pp-ui2fp (ui uitype name)
  (a.prog1 (llvm-sym name :local t)
    (format t "~&~A = uitofp ~A ~A to double~%" it uitype ui)))

(def-llvm-ppd binary-exp (exp)
  (destructuring-bind (_ op lhs rhs) exp
    (declare (ignore _))
    (let ((lval (gen-code lhs))  ; 左側の式のコードを生成
          (rval (gen-code rhs))) ; 右側の式のコードを生成
      (case op
        (#\+ (pp-add lval rval "addtmp"))
        (#\- (pp-sub lval rval "subtmp"))
        (#\* (pp-mul lval rval "multmp"))
        (#\< (pp-ui2fp (pp-< lval rval "cmptmp") "i1" "booltmp")) ; 修正(2010/02/19): uiの型を引数で渡すように変更
        (otherwise
         (error "Invalid binary operator: ~A" op))))))

;;;;;;;;;;;;
;;; 関数呼出
;; 関数の登録/削除/検索関数
(let ((fn-module (make-hash-table :test #'equal)))
  (defun register-fn (name fn) (setf (gethash name fn-module) fn))
  (defun deregister-fn (name)  (remhash name fn-module))
  (defun find-fn (name)        (gethash name fn-module)))

;; 関数用の構造体
(defstruct fn
  name      ; string(関数名)
  arg-types ; string(型名)のlist
  arg-names ; string(引数名)のlist
  ret-type) ; string(戻り値の型)

;; コード生成: 関数呼出
(defun pp-call (fn args name)
  (a.prog1 (llvm-sym name :local t)  ; 修正(2010/02/21): もともとは第二引数が:globalになっていたが、:localが正しい
    (format t "~&~A = call ~A ~A(~{~A ~A~^,~})~%"
            it
            (fn-ret-type fn)
            (fn-name fn)
            (mapcan #'list (fn-arg-types fn) args))))

(def-llvm-ppd call-exp (exp)
  (let ((callee (call-exp-callee exp))
        (args   (call-exp-args   exp)))
    ;; 関数を検索する
    (let ((callee-fn (find-fn callee)))
      (when (null callee-fn)
        (error "Unknown function referenced: ~A" callee))
      
      ;; 引数の数をチェックする
      (when (/= (length (fn-arg-names callee-fn)) (length args))
        (error "Incorrect # arguments passed"))
      
      ;; 関数呼出コードを生成する
      (pp-call callee-fn
               (loop FOR a IN args COLLECT (gen-code a))
               "calltmp"))))

;;;;;;;;;;;;;;;;
;;; プロトタイプ
;; 関数構造体の作成と関数登録を行う
(def-llvm-ppd prototype (exp)
  (let ((name (prototype-name exp))
        (args (prototype-args exp)))
    ;; 関数作成: 引数と戻り値の型は、doubleで固定
    (let ((doubles (loop REPEAT (length args) COLLECT "double")))
      (let ((ft (make-fn :name (llvm-sym name :global)
                         :arg-types doubles
                         :arg-names args
                         :ret-type  "double")))

        ;; TODO: 重複登録をチェックする
        (register-fn name ft)  ; 関数登録
        (mapc #'bind-var args) ; ここで引数の変数を束縛する
        ft))))

;;;;;;;;;;;;
;;; 関数定義
(defun pp-basic-block (block-name fn)
  (with-slots (name arg-types arg-names ret-type) fn
    (format t "~&define ~A ~A (~{~A %~A~^,~}) {~%~A:~%"
            ret-type name
            (mapcan #'list arg-types arg-names)
            block-name)))

(defun pp-ret (retval)
  (format t "~&ret double ~A~%}~%" retval))

(def-llvm-ppd define (exp)
  (princ ; コード生成中にエラーが起こった場合に、中途半端なコードが出力されるのを防ぐために、with-output-to-stringで囲んで、最後にまとめてprincする
    (with-output-to-string (*standard-output*)  
      (let ((proto (define-proto exp))
            (body  (define-body  exp)))
        (clear-vars)                 ; 変数束縛を初期化する
        (let ((ft (gen-code proto))) ; プロトタイプを生成する(ここで新たに変数が束縛される)
          
          ;; 関数定義コード生成
          (pp-basic-block "entry" ft);; extern用の構造体(名前付きリスト)を追加
(defstruct (extern (:type list) :named) proto)

          (pp-ret (gen-code body))

          ;; TODO: オリジナルでは、ここで生成されたコードの妥当性(一貫性?)チェックを行っている
          ft)))))

;;;;;;;;;;;;;;
;;; extern宣言
(def-llvm-ppd extern (exp)
  (with-slots (ret-type name arg-types) (gen-code (extern-proto exp))
    (format t "~&declare ~A ~A (~{~A~^, ~})"
            ret-type name arg-types)))

以上。
かなり場当たり的かつ未整理なコードだけど、一応動きはする。

;; *print-pprint-dispatch*に、*llvm-pp-table*を束縛する
;; ※ 出力コードのインデントは手動で調整した
> (main-loop *llvm-pp-table*)
ready> 4+5;
define double @ () {
entry:
  %addtmp1 = fadd double 4.e+0, 5.e+0
  ret double %addtmp1
}

ready>  def foo(a b) a*a + 2*a*b + b*b;
define double @foo (double %a,double %b) {
entry:
  %multmp2 = fmul double %a, %a
  %multmp3 = fmul double 2.e+0, %a
  %multmp4 = fmul double %multmp3, %b
  %addtmp5 = fadd double %multmp2, %multmp4
  %multmp6 = fmul double %b, %b
  %addtmp7 = fadd double %addtmp5, %multmp6
  ret double %addtmp7
}

ready> def bar(a) foo(a, 4.0) + bar(31337);
define double @bar (double %a) {
entry:
  @calltmp8 = call double @foo(double %a,double 4.e+0)
  @calltmp9 = call double @bar(double 3.1337e+4)
  %addtmp10 = fadd double @calltmp8, @calltmp9
  ret double %addtmp10
}

ready> cos(1.234);
Error: Unknown function referenced: cos

ready> extern cos(x);
declare double @cos (double)

ready>  cos(1.234);
define double @ () {
entry:
  @calltmp11 = call double @cos(double 1.23399995e+0)
  ret double @calltmp11
}

もう少し整理したい気はするけど、三章もとりあえず終了。

おまけ

コード生成lisp(別のディスパッチ表を作成)

(defvar *lisp-pp-table* (copy-pprint-dispatch nil))

(defmacro def-lisp-ppd (type args &body body)
  (setf args (cons 'common-lisp:*standard-output* args))
  `(set-pprint-dispatch
    ',type
    (lambda ,args 
      (declare (ignorable common-lisp:*standard-output*))
      ,@body)
    0 
    *lisp-pp-table*))

(def-lisp-ppd number-exp (exp)
  (princ (number-exp-val exp)))

(def-lisp-ppd variable-exp (exp)
  (princ (variable-exp-name exp)))

(def-lisp-ppd binary-exp (exp)
  (format t "(~C ~A ~A)" 
          (binary-exp-op  exp)
          (binary-exp-lhs exp)
          (binary-exp-rhs exp)))

(def-lisp-ppd call-exp (exp)
  (format t "(~A~{ ~A~})"
          (call-exp-callee exp)
          (call-exp-args   exp)))

(def-lisp-ppd prototype (exp)
  (format t "~A ~A"
          (if (zerop (length #1=(prototype-name exp)))
              "anonymous"
            #1#)
          (or (prototype-args exp) "()")))

(def-lisp-ppd define (exp)
  (format t "(defun ~A ~A)"
          (define-proto exp)
          (define-body exp)))

実行。

> (main-loop *lisp-pp-table*)

ready> 4+5;
(defun anonymous () (+ 4 5))

ready> def foo(a b) a*a + 2*a*b + b*b;
(defun foo (a b) 
  (+ (+ (* a a)         ; 無駄が多い
        (* (* 2 a) b)) 
     (* b b)))

ready> def bar(a) foo(a, 4.0) + bar(31337);
(defun bar (a) 
  (+ (foo a 4.0) (bar 31337)))

ready> extern cos(x);
(EXTERN cos (x))       ; externはそのまま出力

ready> cos(1.234);
(defun anonymous () (cos 1.234))