2012-10-12 4 views
1

Strassen 행렬 곱셈을 Python으로 구현하려고합니다. 나는 그것이 다소 효과가있다. 내 코드는 다음과 같습니다.Strassen 행렬 곱셈 - 버그가 있지만 닫음. 여전히

a = [[1,1,1,1],[2,2,2,2],[3,3,3,3],[4,4,4,4]] 
b = [[5,5,5,5],[6,6,6,6],[7,7,7,7],[8,8,8,8]] 

def new_m(p, q): # create a matrix filled with 0s 
    matrix = [[0 for row in range(p)] for col in range(q)] 
    return matrix 

def straight(a, b): # multiply the two matrices 
    if len(a[0]) != len(b): # if # of col != # of rows: 
     return "Matrices are not m*n and n*p" 
    else: 
     p_matrix = new_m(len(a), len(b[0])) 
     for i in range(len(a)): 
      for j in range(len(b[0])): 
       for k in range(len(b)): 
        p_matrix[i][j] += a[i][k]*b[k][j] 
    return p_matrix 

def split(matrix): # split matrix into quarters 
    a = matrix 
    b = matrix 
    c = matrix 
    d = matrix 
    while(len(a) > len(matrix)/2): 
     a = a[:len(a)/2] 
     b = b[:len(b)/2] 
     c = c[len(c)/2:] 
     d = d[len(d)/2:] 
    while(len(a[0]) > len(matrix[0])/2): 
     for i in range(len(a[0])/2): 
      a[i] = a[i][:len(a[i])/2] 
      b[i] = b[i][len(b[i])/2:] 
      c[i] = c[i][:len(c[i])/2] 
      d[i] = d[i][len(d[i])/2:] 
    return a,b,c,d 

def add_m(a, b): 
    if type(a) == int: 
     d = a + b 
    else: 
     d = [] 
     for i in range(len(a)): 
      c = [] 
      for j in range(len(a[0])): 
       c.append(a[i][j] + b[i][j]) 
      d.append(c) 
    return d 

def sub_m(a, b): 
    if type(a) == int: 
     d = a - b 
    else: 
     d = [] 
     for i in range(len(a)): 
      c = [] 
      for j in range(len(a[0])): 
       c.append(a[i][j] - b[i][j]) 
      d.append(c) 
    return d 


def strassen(a, b, q): 
    # base case: 1x1 matrix 
    if q == 1: 
     d = [[0]] 
     d[0][0] = a[0][0] * b[0][0] 
     return d 
    else: 
     #split matrices into quarters 
     a11, a12, a21, a22 = split(a) 
     b11, b12, b21, b22 = split(b) 

     # p1 = (a11+a22) * (b11+b22) 
     p1 = strassen(add_m(a11,a22), add_m(b11,b22), q/2) 

     # p2 = (a21+a22) * b11 
     p2 = strassen(add_m(a21,a22), b11, q/2) 

     # p3 = a11 * (b12-b22) 
     p3 = strassen(a11, sub_m(b12,b22), q/2) 

     # p4 = a22 * (b12-b11) 
     p4 = strassen(a22, sub_m(b12,b11), q/2) 

     # p5 = (a11+a12) * b22 
     p5 = strassen(add_m(a11,a12), b22, q/2) 

     # p6 = (a21-a11) * (b11+b12) 
     p6 = strassen(sub_m(a21,a11), add_m(b11,b12), q/2) 

     # p7 = (a12-a22) * (b21+b22) 
     p7 = strassen(sub_m(a12,a22), add_m(b21,b22), q/2) 


     # c11 = p1 + p4 - p5 + p7 
     c11 = add_m(sub_m(add_m(p1, p4), p5), p7) 

     # c12 = p3 + p5 
     c12 = add_m(p3, p5) 

     # c21 = p2 + p4 
     c21 = add_m(p2, p4) 

     # c22 = p1 + p3 - p2 + p6 
     c22 = add_m(sub_m(add_m(p1, p3), p2), p6) 

     c = new_m(len(c11)*2,len(c11)*2) 
     for i in range(len(c11)): 
      for j in range(len(c11)): 
       c[i][j]     = c11[i][j] 
       c[i][j+len(c11)]   = c12[i][j] 
       c[i+len(c11)][j]   = c21[i][j] 
       c[i+len(c11)][j+len(c11)] = c22[i][j] 

     return c 

print "Strassen Outputs:" 
print strassen(a, b, 4) 
print "Should be:" 
print straight(a, b) 

적절한 출력을 참조하기 위해 직선 행렬 곱셈을 포함 시켰습니다. 기본적으로이 상황이 발생합니다

쉬트 라쎈 출력을 :

[[10, 14, 22, 26], [32, 36, 48, 52], [58, 66, 70, 78], [80, 88, 96, 104]] 

은 다음과 같아야합니다

[[26, 26, 26, 26], [52, 52, 52, 52], [78, 78, 78, 78], [104, 104, 104, 104]] 

나는 문제의 원인 내가 그것을 해결할 수없는 의미 무엇인지 확실하지 않다! 이

답변

2

을 안 :

# p4 = a22 * (b12-b11) 
p4 = strassen(a22, sub_m(b12,b11), q/2) 

가 될 : 대신

# p4 = a22 * (b21-b11) 
p4 = strassen(a22, sub_m(b21,b11), q/2) 

?

~/coding$ python -i strass.py 
Strassen Outputs: 
[[26, 26, 26, 26], [52, 52, 52, 52], [78, 78, 78, 78], [104, 104, 104, 104]] 
Should be: 
[[26, 26, 26, 26], [52, 52, 52, 52], [78, 78, 78, 78], [104, 104, 104, 104]] 
>>> import numpy 
>>> def check(): 
...  for i in range(100): 
...   a = numpy.random.randint(0, 10,size=(4,4)).tolist() 
...   b = numpy.random.randint(0, 10,size=(4,4)).tolist() 
...   assert strassen(a,b,4) == straight(a,b) 
...   assert (numpy.array(strassen(a,b,4)) == numpy.dot(a,b)).all() 
...  print 'hooray!' 
... 
>>> check() 
hooray! 
+0

AAAHHH, 부주의로 인한 실수를 다른 버전을 썼다! 정말 고마워! – benwiz

0

나는 ... 추가()와 하위()를 단순화하기 위해 NumPy와 함께

import numpy as np 
def straight(a, b): 
    if len(a[0]) != len(b): return "Matrices are not m*n and n*p" 
    p_matrix = np.zeros((len(a), len(b[0]))) 
    p_matrix += [[np.sum([a[i][k] * b[k][j] for k in range(len(b))]) for j in range(len(b[0]))] for i in range(len(a))] 
    return p_matrix 
def split(matrix): # split matrix into quarters 
    row, col = matrix.shape 
    return matrix[:row//2, :col//2], matrix[:row//2, col//2:], matrix[row//2:, :col//2], matrix[row//2:, col//2:] 
def strassen(a, b): 
    q = len(a) 
    if q == 1: # base case: 1x1 matrix 
     return a * b 
    a11, a12, a21, a22 = split(a) 
    b11, b12, b21, b22 = split(b) 
    p1 = strassen(a11 + a22, b11 + b22) # p1 = (a11 + a22) * (b11 + b22) 
    p2 = strassen(a21 + a22, b11)  # p2 = (a21 + a22) * b11 
    p3 = strassen(a11, b12 - b22)  # p3 = a11 * (b12 - b22) 
    p4 = strassen(a22, b21 - b11)  # p4 = a22 * (b21 - b11) 
    p5 = strassen(a11 + a12, b22)  # p5 = (a11 + a12) * b22 
    p6 = strassen(a21 - a11, b11 + b12) # p6 = (a21 - a11) * (b11 + b12) 
    p7 = strassen(a12 - a22, b21 + b22) # p7 = (a12 - a22) * (b21 + b22) 
    c11 = p1 + p4 - p5 + p7 # c11 = p1 + p4 - p5 + p7 
    c12 = p3 + p5   # c12 = p3 + p5 
    c21 = p2 + p4   # c21 = p2 + p4 
    c22 = p1 + p3 - p2 + p6 # c22 = p1 + p3 - p2 + p6 
    c = np.vstack((np.hstack((c11, c12)), np.hstack((c21, c22)))) 
    return c 
def check(): 
    a = np.random.randint(0, 10, size=(16, 16)) 
    b = np.random.randint(0, 10, size=(16, 16)) 
    assert (strassen(a, b) == straight(a, b)).all() 
    assert (np.array(strassen(a, b)) == np.dot(a, b)).all() 
    print('Hooray!') 
check()