稀疏矩阵的乘法

太太太太太太太太太太难了.

严蔚敏书上的算法, 我只是写出来, 我这辈子估计都理解不了.

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
#include <stdio.h>
#include <stdlib.h>

#define MAXSIZE 12500 //非零元素最大个数

typedef int ElemType; //矩阵元素类型, 这里是整型

typedef struct
{
int i, j; //元素的坐标
ElemType e; //元素的值
}Triple;

typedef struct
{
int * rpos; //这个指针会指向一个动态数组, 存储各行第一个元素的位置
Triple data[MAXSIZE]; //非零元素数组, 0 号元素未使用
int mu, nu, tu; //行数, 列数, 非零元个数
}TSMatrix;

//初始化矩阵
void Init(TSMatrix * TSM)
{
int i, j, num;
printf("输入非零元素的个数, 行数, 列数: \n");
scanf("%d %d %d", &TSM->tu, &TSM->mu, &TSM->nu); //输入非零元素的个数
for(i=1; i<=TSM->tu; i++)
{
printf("输入坐标(从 1 开始), 值: \n");
scanf("%d %d %d", &TSM->data[i].i, &TSM->data[i].j, &TSM->data[i].e);
}
TSM->rpos = (int*)malloc(sizeof(int)*(TSM->mu+1)); //根据矩阵的行数生成一个数组, 存储各行第一个元素出现的位置
TSM->rpos[1] = 1; //第一行第一个元素在非零数组第一位
for(i=2; i<=TSM->mu; i++) //从第二行开始, 每行第一个元素的位置 = 前一行第一个元素的位置 + 前一行元素的个数
{
j=1; //用 j 遍历整个非零数组
num=0; //用 num 统计某一行元素个数
while(j<=TSM->tu) //遍历非零数组
{
if(TSM->data[j].i == i-1) //如果元素是前一行的
{
num++; //元素个数 ++
}
j++;
}
TSM->rpos[i] = TSM->rpos[i-1] + num; //前一行第一个元素的位置 + 前一行元素的个数
/*
如果前一行没有元素, 那么这一行第一个元素的位置 = 前一行第一个元素的位置
而在做某些操作的时候, 是读取到非零元素, 然后根据它的行号, 通过 rpos[行号] 来确定位置
某一行没有元素, 在那一行读取不到元素, 也就不会做什么事情
*/
}
}

void Show(TSMatrix * TSM)
{
int i;
printf("矩阵信息如下: \n");
for(i=1; i<=TSM->tu; i++)
{
printf("%-2d %-2d %-2d\n", TSM->data[i].i, TSM->data[i].j, TSM->data[i].e);
}
printf("\n");
printf("\n矩阵首行元素位置如下: \n");
for(i=1; i<=TSM->mu; i++)
{
printf("%-2d : %-2d \t", i, TSM->rpos[i]);
}
printf("\n\n");
}

int Mult(TSMatrix M, TSMatrix N, TSMatrix *Q) //矩阵 M 乘矩阵 N 得到矩阵 Q
{
int arow, brow, ccol, p, tp, t, q;
int * ctemp = (int*)malloc(sizeof(int)*(N.nu+1));
if(M.nu != N.mu) return 0; //如果 M 的列不等于 N 的行, 不符合矩阵相乘的条件
Q->rpos = (int*)malloc(sizeof(int)*(M.mu+1));
Q->mu = M.mu; Q->nu = N.nu; Q->tu = 0; //Q 初始化, 相乘后, Q 的行等于 M 的行, Q 的列等于 N 的列
if(M.tu * N.tu != 0) //这里确保有一个矩阵不为空, 如果是全为 0 的矩阵, 相乘的结果肯定是空矩阵
{
for(arow=1; arow<=M.mu; ++arow) //处理 M 的每一行
{
for(p=1; p<=N.nu+1; p++)
{
ctemp[p] = 0;
}
Q->rpos[arow] = Q->tu + 1; //记录 Q 各行首元素的位置
if(arow<M.mu)
{
tp = M.rpos[arow+1];
//tp 下一行首元素的位置, 遍历本行元素时(下一个 for 循环)用到
//如果不是最后一行, tp 可以从 rpos 数组中得到
}
else
{
tp = M.tu + 1; //如果是最后一行, 那么 tp 等于总个数加一
}
for(p=M.rpos[arow]; p<tp; p++) //遍历当前行, p 从本行首元素开始, 大于等于 tp 时跳出循环
{
brow = M.data[p].j; //矩阵相乘是 M列号==N行号 的元素
if(brow<N.mu)
{
//这个 if 跟上面的作用相似, t 用于循环 N 中的列
t = N.rpos[brow+1];
}
else
{
t = N.tu + 1;
}
for(q=N.rpos[brow]; q<t; q++)
{
ccol = N.data[q].j; //乘积元素在 Q 中的列号
ctemp[ccol] += M.data[p].e * N.data[q].e; //Q(arow,ccol) 元素的值 //懵逼...
}
}
for(ccol=1; ccol<=Q->nu; ccol++)
{
if(ctemp[ccol])
{
if(++Q->tu > MAXSIZE) return 0;
Q->data[Q->tu].i = arow;
Q->data[Q->tu].j = ccol;
Q->data[Q->tu].e = ctemp[ccol];
}
}
}
}
return 1;
}

int main(int argc, char const *argv[])
{
TSMatrix a, b, c;
Init(&a);
Show(&a);
Init(&b);
Show(&b);

Mult(a, b, &c);
Show(&c);

return 0;
}