読者です 読者をやめる 読者になる 読者になる

簡単なモナド実装

common lisp

Haskellモナドの簡単なものをcommon lispで実装してみたので、コードを載せておく。
※ 対象としたのは、maybeモナド、listモナド、stateモナド、continuationモナドの四つで、それぞれの定義は『All About Monads』の第三部を参考にさせてもらった。(各モナドの説明等も上記サイトを参照のこと。)

utility

(defpackage :monad.util
  (:use :common-lisp)
  (:export :def-do
           :def->>=*))
(in-package :monad.util)

;; [do記法]
;; (do (var1 <- monad-exp1)
;;      ...
;;     (var2 <- monad-exp2)
;;      ...
;;     (return (list var1 var2)))
;;
;; 特定のモナド用のdoマクロを定義するマクロ
;; ※ このマクロを呼び出すpackage内で、>>=が適切に定義されている必要がある
(defmacro def-do (&optional (name (intern "DO")) (>>= (intern ">>=")))
  `(defmacro ,name (&rest exps)
     `(progn ,@(def-do-expand ',>>= exps))))

(defun def-do-expand (>>= exps)
  (flet ((bind-exp? (exp)
           (and (consp exp) (string= (princ-to-string (second exp)) "<-"))))
    (when exps
      (destructuring-bind (exp . exps) exps
        (if (bind-exp? exp)
            (destructuring-bind (var <- monad-exp) exp
              (declare (ignore <-))
              `((,>>= ,monad-exp (lambda (,var)
                                   ,@(def-do-expand >>= exps)))))
          (cons exp (def-do-expand >>= exps)))))))


;; [>>=の可変長引数版]
;; (>>=* (return ...) monad-fn1 monad-fn2 ...)
;;
;; 特定のモナド用の>>=*マクロを定義するマクロ
;; ※ このマクロを呼び出すpackage内で、>>=が適切に定義されている必要がある
(defmacro def->>=* (&optional (name (intern ">>=*")) (>>= (intern ">>=")))
  `(defmacro ,name (initial-exp &rest exps)
     (def->>*-expand ',>>= (cons initial-exp exps))))

(defun def->>*-expand (>>= exps)
  (if (= 1 (length exps))
      (car exps)
    (destructuring-bind (left right . exps) exps
      (def->>*-expand >>= (cons `(,>>= ,left ,right) exps)))))

maybe

(defpackage :monad.maybe
  (:use :common-lisp)
  (:shadow :do :return)
  (:export :return
           :just
           :nothing
           :>>=
           :>>=*
           :do))
(in-package :monad.maybe)

;; return
(defun return (x) (just x))
(defun just (x) (values x t))
(defun nothing () (values nil nil))

;; >>= 
(defmacro >>= (monad-exp monad-fn)
  (let ((value (gensym)) (win (gensym)))
    `(multiple-value-bind (,value ,win) ,monad-exp
       (if ,win
           (funcall ,monad-fn ,value)
         (nothing)))))

(monad.util:def-do)
(monad.util:def->>=*)


;;; 使用例 ;;;
(defvar *data*
  `((:number (:odd  1 3 5 7) (:even 0 2 4 6))
    (:string (2 "ab" "12") (3 "cde" "345"))))
             
(defun lookup (key alist)
  (let ((rlt (assoc key alist)))
    (if rlt
        (just (cdr rlt))
      (nothing))))

(defun div (num denom)
  (if (zerop denom)
      (nothing)
    (just (/ num denom))))

;;
> (>>=* (lookup :number *data*)
        (lambda (data) (lookup :even data))
        (lambda (nums) (reduce #'+ nums)))
--> 12 T

> (>>=* (lookup :symbol *data*) 
        (lambda (data) (lookup :function data))
        (lambda (func) (funcall func 10)))
--> NIL NIL

> (do (data   <- (lookup :number *data*))
      (nums   <- (lookup :even data))
        (format t "~&Number list: ~A~%" nums)
      (num    <- (return (second nums)))
      (10/num <- (div 10 num))
        (return `(10 / ,num = ,10/num)))
Number list: (0 2 4 6)
--> (10 / 2 = 5) T

> (do (data   <- (lookup :number *data*))
      (nums   <- (lookup :even data))
        (format t "~&Number list: ~A~%" nums)
      (num    <- (return (first nums)))
      (10/num <- (div 10 num))  ; ゼロ除算
        (return `(10 / ,num = ,10/num)))
Number list: (0 2 4 6)
--> NIL NIL

list

(defpackage :monad.list
  (:use :common-lisp)
  (:shadow :do :return)
  (:export :return
           :return-list
           :>>=
           :>>=*
           :do))
(in-package :monad.list)

;; return
(defun return (x)  (list x))
(defun return-list (list) list)

;; >>=
(defun >>= (monad-exp monad-fn &aux (list monad-exp))
  (mapcan monad-fn (copy-list list)))

(monad.util:def-do)
(monad.util:def->>=*)

;;; 使用例 ;;;
> (do (elem1 <- (return-list '(1 2 3)))
      (elem2 <- (return-list (loop repeat elem1 collect elem1)))
      (elem3 <- (return (* elem2 elem2)))
        (return `#(,elem1 => ,elem3)))
--> (#(1 => 1) #(2 => 4) #(2 => 4) #(3 => 9) #(3 => 9) #(3 => 9))

(do (elem   <- '(1 2 3))
    (elem^2 <- (monad.list:return (expt elem 2)))
    (elem^3 <- (monad.list:return (expt elem 3)))
    (elem   <- (list elem elem^2 elem^3))
      (return elem))

state

(defpackage :monad.state
  (:use :common-lisp)
  (:shadow :do :return)
  (:export :return
           :>>=
           :>>=*
           :do  
           :run
           :bind-state
           :get-state
           :put-state))
(in-package :monad.state)

;; return
(defun return (value &key (new-state #1='#:undef))
  (lambda (state)
    (values value (if (eq #1# new-state) state new-state))))

;; >>=
(defun >>= (monad-exp monad-fn)
  (lambda (state)
    (multiple-value-bind (value new-state) (run state monad-exp)
      (let ((next-monad-exp (funcall monad-fn value)))
        (run new-state next-monad-exp)))))

(monad.util:def-do)
(monad.util:def->>=*)

;; stateつきでモナドを実行する
(defun run (state monad)
  (funcall monad state))

;; state操作系関数: (do (... <- exp) ...)のexp部分で使われることを想定
(defun get-state ()
  (lambda (state)
    (values state state)))

(defun put-state (new-state)
  (lambda (old-state)
    (values old-state new-state)))


;;; 使用例 ;;;
;; stateを使用しない
> (run :state
    (do (val   <- (return 10))
        (val^2 <- (return (* val val)))
        (val^6 <- (return (* val^2 val^2 val^2)))
          (return val^6)))
--> 1000000 :STATE

;; 途中でstateを取得して、表示する
> (run :state
    (do (val   <- (return 10))
        (state <- (get-state))
          (format t "~&STATE = ~S~%" state)
        (val^6 <- (return (expt val 6)))
          (return val^6)))
STATE = :STATE
--> 1000000 :STATE

;;; stateを変更する
> (run 'original-state
    (do  (val       <- (return 10))
         (old-state <- (get-state))
         (val^2     <- (return (* val val)))
         (_         <- (put-state 'new-state))
           (declare (ignore _))
         (val^6     <- (return (* val^2 val^2 val^2)))
         (new-state <- (get-state))
           (return `(:value ,val^6 :old ,old-state :new ,new-state))))
--> (:VALUE 1000000 :OLD ORIGINAL-STATE :NEW NEW-STATE) NEW-STATE

continuation

(defpackage :monad.continuation
  (:use :common-lisp)
  (:shadow :do :return :return-from)
  (:export :return
           :return-from
           :>>=
           :>>=*
           :do
           :run
           :call/cc
           :call/cc-block))
(in-package :monad.continuation)

(defvar *initial-continuation* #'identity)
(defun run (monad &key (cont *initial-continuation*))
  (funcall monad cont))

;; return
(defun return (value)
  (lambda (cont)
    (funcall cont value)))

;; >>=
(defun >>= (monad-exp monad-fn &aux (cont monad-exp))
  (lambda (cc)
    (funcall cont 
             (lambda (value) 
               (run (funcall monad-fn value) :cont cc)))))

(monad.util:def-do)
(monad.util:def->>=*)


;; call current continuation
(defun call/cc (fn)
  (lambda (cc)
    (run (funcall fn
                  (lambda (value)
                    (lambda (unused-cont)
                      (declare (ignore unused-cont))
                      (funcall cc value))))
         :cont cc)))

;; call/ccを、common lispのblock + return-fromと似た感じで使えるようにしたもの
;;
;; (block ABC
;;    ... ... 
;;    (common-lisp:return-from ABC value))
;;
;; (call/cc-block ABC
;;    ... ...
;;    (monad.continuation:return-from ABC value))
(defmacro call/cc-block (return-block &body body)
  `(call/cc (lambda (,return-block)
              (declare (ignorable ,return-block))
              ,@body)))

;; continuationを取得する: call/cc-blockとの併用を想定
(defun get/c (cont)
  (lambda (value)
    (funcall (funcall cont value) nil)))

(defun return-from (cont value)
  (funcall cont value))


;;; 使用例 ;;;
;; ※ continuationモナドを返す式は、
;;    (progn ... (do ...) ...)や
;;    (let (...) (when ... (return-from ...)) ...)
;;    などのように式の途中に置くことは出来ない。
;;      --> 必ず一連の式の末尾に置く必要がある。

> (defvar *level0*)
> (defvar *level1*)

;; x1が偶数ならx1+x1/2を、奇数ならx1+1を、x2にセットしている
> (run
   (do (x1 <- (return 9))
       (x2 <- (call/cc-block LEVEL0  ; LEVEL0
                (setf *level0* (get/c LEVEL0))  ; 継続を保存しておく
                (do (y1 <- (return x1))
                    (y2 <- (call/cc-block LEVEL1  ; LEVEL1
                             (setf *level1* (get/c LEVEL1))  ; 継続を保存しておく
                             (if (oddp y1)
                                 (do (z1 <- (return (+ y1 1)))
                                       (return-from LEVEL0 z1)) ; 奇数の場合、LEVEL0から脱出(?)する
                               (return (/ y1 2)))))
                    (y3 <- (return (+ y1 y2))) ; 奇数の場合、ここは実行されない
                      (return y3))))
         (return `(,x1 -> ,x2))))
--> (9 -> 10)

;; 保存しておいた継続を使って、LEVEL0から脱出(実行再開)する。
> (return-from *level0* 100) ; --> x2に100をバインドして、以降の処理を実行
--> (9 -> 100)

;; 保存しておいた継続を使って、LEVEL1から脱出(実行再開)する。
> (return-from *level1* 100)  ; --> y2に100をバインドして、以降の処理を実行
--> (9 -> 109)