方案中的矩阵乘法,列表列表

Matrix multiplication in scheme, List of lists

我开始研究Scheme,但我不了解其中的一些内容。我正在使用DrRacket。

我写了以下代码:

1
2
3
4
(define mult_mat
  (λ (A B)
    (Trans_Mat (map (λ (x) (mul_Mat_vec A x))
                    (Trans_Mat B)))))

使用以下功能:

1
2
3
(define Trans_Mat
  (λ (A)
    (apply map (cons list A))))
1
2
3
4
(define mul_Mat_vec
  (λ (A v)
    (map (λ (x) (apply + (map * x v)))
         A)))

mult_mat中,我将矩阵A乘以转置矩阵B的每个向量。
效果很好。

我在网上找到了一个以我不了解的方式进行乘法运算的代码:

1
2
3
4
5
6
7
8
(define (matrix-multiply matrix1 matrix2)
  (map
   (λ (row)
     (apply map
       (λ column
         (apply + (map * row column)))
       matrix2))
   matrix1))

在此代码中,row是矩阵A列表的列表,但我不了解column的更新方式。

这部分代码:(apply + (map * row column))是向量row和向量column

的点积

例如:A是矩阵2X3,B是矩阵3X2,如果我写的是1而不是(apply + (map * row column)),那么我将得到一个矩阵2X2,其条目值为1

我不知道它是如何工作的。

谢谢。


啊,旧的( apply map foo _a_list_ )技巧。非常聪明。

实际上(apply map (cons list A))(apply map list A)相同。这就是定义apply的工作方式。

尝试一些具体示例通常有助于"获取":

1
2
3
4
5
6
7
8
9
(apply map       list '((1 2 3)  (10 20 30)) )
=
(apply map (cons list '((1 2 3)  (10 20 30))))
=
(apply map (list list  '(1 2 3) '(10 20 30) ))
=
(      map       list  '(1 2 3) '(10 20 30)  )
=
'((1 10) (2 20) (3 30))

,以便最后一个参数'((1 2 3) (10 20 30))的元素被拼接成整体apply map ...形式。

矩阵转置(真正的列表列表)。

所以你有

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
(define (mult_mat A B)
    (Trans_Mat (map (λ (B_column) (mul_Mat_vec A B_column))
                    (Trans_Mat B))))

(define (Trans_Mat A)
    (apply map list A))

(define (mul_Mat_vec A v)
    (map (λ (A_row) (apply + (map * A_row v)))
         A))

(define (matrix-multiply A B)
  (map
    (λ (A_row)
      (apply map
             (λ B_column
               (apply + (map * A_row B_column)))
             B))
    A))

请注意,它是(λ B_column ...,没有括号。在((λ args ...) x y z)中,当输入lambda时,args获取打包在列表中的所有参数:

1
2
3
4
((λ args ...) x y z)
=
(let ([args (list x y z)])
  ...)

也请注意

1
2
3
4
      (apply map
             (λ B_column
               (apply + (map * A_row B_column)))
             B)

遵循相同的"棘手"模式。实际上与

相同

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
      (apply map (cons
             (λ B_column
               (apply + (map * A_row B_column)))
             B    ) )
=
      (      map
             (λ B_column
                (apply + (map * A_row B_column)))
             B_row1
             B_row2
             ....
             B_rowN )
=
     (cons (let ([B_column_1 (map car B)])
              (apply + (map * A_row B_column_1)))
           (map (λ B_column
                    (apply + (map * A_row B_column)))
             (cdr B_row1)
             (cdr B_row2)
             ....
             (cdr B_rowN)) )
=
     (cons
       (apply (λ B_column (apply + (map * A_row B_column)))
              (map car B))
       (apply map
              (λ B_column
                 (apply + (map * A_row B_column)))
              (map cdr B)))

通过map的定义。

因此,通过应用map,矩阵将"打开"到其元素列表中的行,然后,当多参数map作为这些参数的行开始处理这些行时,lambda函数将相应地统一应用于每行的后续数字;从而达到与显式换位相同的效果。但是现在增加的好处是,我们不需要像使用第一个版本那样将结果转回正确的形式。

这非常聪明,也很不错。

因此,基于所有这些理解,让我们尝试重新阅读原始代码,看看我们是否也可以照原样查看它。

1
2
3
4
5
6
7
8
(define (matrix-multiply matrix1 matrix2)
  (map
   (λ (row)
     (apply map
       (λ column      ;; <<------ no parens!
         (apply + (map * row column)))
       matrix2))
   matrix1))

内容为:对于matrix1中的每个row,在matrix2上的多个参数maplambdamatrix2本身也是一个行列表;当我们在行上使用arg- map时,lambda依次应用于行中的每一列。

因此,对于matrix1中的每一行,对于matrix2中的每一列,将该行和该列逐元素相乘并求和;因此,将每一行转换为这些总和的列表。仅当行的长度和每列的长度相同时,这显然可行:如果第一个矩阵的"宽度"和第二个矩阵的"高度"相同。


如果您更喜欢使用while循环(对于初学者可能更容易),建议将问题分为7个主要的辅助函数(以及一些其他简单函数):

(到目前为止)这不是最有效的方法,但是很容易理解

  • getRow mat i:获取矩阵mat的第i行(列表列表)

    1
    2
     (define (getRow mat i)
        (nthElement mat i))
    1
    2
    3
    4
    (define (nthElement lisT n)
        (if (= n 0)
            (car lisT)                                                
            (nthElement (cdr lisT) (- n 1))))
  • getCol mat i:获取矩阵mat的第i列(列表列表)

    1
    2
    3
    4
    5
    6
    7
    (define (getCol mat i)
        (define col (list))
        (define row 0)
        (while (< row (length mat))
            (set! col (append col (list (valueAtIJ mat row i))))
            (set! row (+ row 1)))
         col)
    1
    2
    (define (valueAtIJ mat i j)
        (nthElement (nthElement mat i) j))
  • listMult list1 list2:在两个列表上执行逐元素乘法

    1
    2
    3
    4
    (define (listMult list1 list2)
        (if (not (null? list1))
            (cons (* (car list1) (car list2)) (listMult (cdr list1) (cdr list2)))
            null))
  • sum aList:计算列表中所有元素的总和。

    1
    2
    3
    4
    (define (sum aList)
        (if (null? aList)
            0
            (+ (car aList) (sum (cdr aList)))))
  • length aList:查找列表的长度

    1
    2
    3
    4
    (define (length lisT)
        (if (null? lisT)                                              
            0
            (+ 1 (length (cdr lisT)))))
  • newMatrix m n val:创建一个由val

    填充的m x n矩阵

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    (define (newMatrix m n val)                        
        (define i 0)
        (define row (list val))
        (define mat (list))
        (if (= n 0)
            (list)                                          
            (begin
                (while (< i (- n 1))
                    (set! row (append row (list val)))
                    (set! i (+ i 1)))
                (set! i 0)
                (while (< i m)
                    (set! mat (append mat (list row)))    
                    (set! i (+ i 1)))
        mat)))
  • setValueAtIJ mat i j val:在mat中的位置i,j处设置值val(从0开始)

    1
    2
    3
    (define (setValueAtIJ mat i j val)
        (set! mat (setNthElementFinal mat i (setNthElementFinal (nthElement mat i) j val)))
        mat)
  • 这些都可以组合起来以创建矩阵乘法函数

    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    (define (matrixMult mat1 mat2)
        (define mat1Dim (list (length mat1) (length (nthElement mat1 0))))      
        (define mat2Dim (list (length mat2) (length (nthElement mat2 0))))      
        (define i 0)
        (define j 0)
        (define newMat (newMatrix (car mat1Dim) (car (cdr mat2Dim)) 0))        
        (if (not (= (car (cdr mat1Dim)) (car mat2Dim)))
            null                                                                
            (begin
                (while (< i (length newMat))
                    (while (< j (length (nthElement newMat 0)))
                        (set! newMat (setValueAtIJ newMat i j (sum (listMult (getRow mat1 i) (getCol mat2 j)))))
                        (set! j (+ j 1)))
                    (set! j 0)
                    (set! i (+ i 1)))
            newMat)))