Implement copy-on write

背景

xv6 使用 fork() 系统调用创建子进程时,需要将父进程的地址空间进行 深拷贝 ,即将页表和实际物理空间同时进行拷贝,以实现父进程和子进程地址空间的独立性。但很多时候,如 shell 程序,fork() 通常与 exec() 搭配使用,首先使用 fork() 创建子进程,随后在子进程中使用 exec() 将指定的程序加载到当前地址空间,这样在 fork() 中进行的地址空间拷贝就白白浪费了。

本实现要求实现一个写时复制(copy-on write)的 fork() 系统调用。具体来说,在进行虚拟内存拷贝时,不直接进行物理内存的拷贝,只是将父进程的页表复制给子进程,这样子进程和父进程的每个虚拟页面都指向了同一个物理页面,当子进程需要对某个虚拟页面进行写入时,为了保证父进程和子进程之间的独立性,子进程此时将进行物理内存的分配和拷贝,再进行写入。

实现方案

根据提示,可以将上述的写时复制的思路用 异常 的方式来实现。

首先可以利用页表项的 flags 中的 RSW 位来表示页表项是否为 COW 页,以便后续的异常处理。

修改 uvmcopy() ,将物理页面的分配操作去除,只是进行页表的拷贝,并将父进程和子进程的对应页表项的 PTE_W 置 0(以便在对 COW 页进行写入时陷入内核)、PTE_COW 置 1。

修改 usertrap(),当陷入内核时,内核通过查看 scause 寄存器(见下图)以及页表项的 PTE_W 和 PTE_COW 位,识别到陷入原因是发生在 COW 页上的 store page fault(寄存器值为 15)时,进行对应的异常处理:使用 kalloc() 为其分配物理页面,并将其页表项指向的物理地址数据拷贝到新分配的物理地址下,实现物理内存的拷贝。此时由于页表映射发生了改变,需要插入新的页表项,并删除旧的页表项。在处理了 COW 异常之后,该页面将不再是一个 COW 页,因此需要将 PTE_W 置 1、PTE_COW 置 0。

为了后续实现的方便,可以将 COW 页的判断和 COW 页的异常处理分别封装为两个函数:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
int iscowpage(pagetable_t pgtbl, uint64 va) {
if (va >= MAXVA) return 0;
pte_t *pte = walk(pgtbl, va, 0);
if (pte == 0) return 0;
if ((*pte & PTE_V) == 0) return 0;
if ((*pte & PTE_U) == 0) return 0;
return *pte & PTE_COW;
}

int cowfault(pagetable_t pagetable, uint64 va) {
uint64 va0 = PGROUNDDOWN(va);
pte_t* pte;
if((pte = walk(pagetable, va0, 0)) == 0) return -1;

uint64 flags = PTE_FLAGS(*pte);
uint64 pa0 = PTE2PA(*pte);

flags &= (~PTE_COW); // clear COW bit
flags |= PTE_W; // set write bit

uint64 mem;
if ((mem = (uint64)kalloc()) == 0) return -1;
memmove((void *)mem, (void *)pa0, PGSIZE);

// remove old PTE
uvmunmap(pagetable, va0, 1, 1);

// install new PTE
if(mappages(pagetable, va0, PGSIZE, mem, flags) < 0){
kfree((void *)mem);
return -1;
}
return 0;
}

此外,还需要为每个物理页面引入 引用计数(reference count) ,页面创建时计数为 1,每次添加或移除指向该物理地址的页表项都增加或减少引用计数,当引用计数为 0 时释放该物理页面。这里有一个实现的技巧:将引用计数的减少放到 kfree() 中,在 kfree() 中根据引用计数的大小决定是否释放物理页面。

最后,也是很容易忽视的一点,修改 copyout() 以实现对 COW 页的支持。刚开始看到这个提示的时候我很疑惑,前面的工作貌似已经足够实现 COW 了,为什么还要修改 copyout?原来 xv6 对 COW 页进行写时复制都是基于 store page fault,即当尝试写入一个 PTE_W 为 0 的页面时触发异常,导致陷入内核,再由内核进行 COW 页面的异常处理,其中陷入内核的操作是由硬件自动来完成的,具体来说,是在虚实地址转换阶段由 MMU 来完成的。而 copyout() 是运行在内核态下的函数,其地址转换是由内核中的函数 walk() 来实现的,因而不会自动触发异常并交由异常处理程序来处理,而需要手动来完成。由于前面已经将 COW 页的判断和处理封装成了函数,因此对 copyout() 的修改很简单:

1
2
3
if (iscowpage(pagetable, va0)) {
cowfault(pagetable, va0);
}

代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
diff --git a/kernel/defs.h b/kernel/defs.h
index 3564db4..f5a9d8d 100644
--- a/kernel/defs.h
+++ b/kernel/defs.h
@@ -63,6 +63,7 @@ void ramdiskrw(struct buf*);
void* kalloc(void);
void kfree(void *);
void kinit(void);
+void incrfcount(void*);

// log.c
void initlog(int, struct superblock*);
@@ -145,6 +146,8 @@ void trapinit(void);
void trapinithart(void);
extern struct spinlock tickslock;
void usertrapret(void);
+int iscowpage(pagetable_t, uint64);
+int cowfault(pagetable_t, uint64);

// uart.c
void uartinit(void);
@@ -170,6 +173,7 @@ uint64 walkaddr(pagetable_t, uint64);
int copyout(pagetable_t, uint64, char *, uint64);
int copyin(pagetable_t, char *, uint64, uint64);
int copyinstr(pagetable_t, char *, uint64, uint64);
+pte_t* walk(pagetable_t, uint64, int);

// plic.c
void plicinit(void);
diff --git a/kernel/kalloc.c b/kernel/kalloc.c
index fa6a0ac..5872b85 100644
--- a/kernel/kalloc.c
+++ b/kernel/kalloc.c
@@ -14,6 +14,11 @@ void freerange(void *pa_start, void *pa_end);
extern char end[]; // first address after kernel.
// defined by kernel.ld.

+#define PA2RFIDX(pa) ((((uint64)pa) - KERNBASE) / PGSIZE)
+
+int rfcount[(PHYSTOP - KERNBASE) / PGSIZE];
+struct spinlock rflock;
+
struct run {
struct run *next;
};
@@ -27,6 +32,7 @@ void
kinit()
{
initlock(&kmem.lock, "kmem");
+ initlock(&rflock, "rflock");
freerange(end, (void*)PHYSTOP);
}

@@ -51,15 +57,17 @@ kfree(void *pa)
if(((uint64)pa % PGSIZE) != 0 || (char*)pa < end || (uint64)pa >= PHYSTOP)
panic("kfree");

- // Fill with junk to catch dangling refs.
- memset(pa, 1, PGSIZE);
-
- r = (struct run*)pa;
-
- acquire(&kmem.lock);
- r->next = kmem.freelist;
- kmem.freelist = r;
- release(&kmem.lock);
+ acquire(&rflock);
+ if(--rfcount[PA2RFIDX(pa)] <= 0){
+ memset(pa, 1, PGSIZE);
+ // Fill with junk to catch dangling refs.
+ r = (struct run*)pa;
+ acquire(&kmem.lock);
+ r->next = kmem.freelist;
+ kmem.freelist = r;
+ release(&kmem.lock);
+ }
+ release(&rflock);
}

// Allocate one 4096-byte page of physical memory.
@@ -76,7 +84,15 @@ kalloc(void)
kmem.freelist = r->next;
release(&kmem.lock);

- if(r)
+ if(r) {
memset((char*)r, 5, PGSIZE); // fill with junk
+ rfcount[PA2RFIDX(r)] = 1;
+ }
return (void*)r;
}
+
+void incrfcount(void* pa){
+ acquire(&rflock);
+ ++rfcount[PA2RFIDX(pa)];
+ release(&rflock);
+}
\ No newline at end of file
diff --git a/kernel/riscv.h b/kernel/riscv.h
index 1691faf..a6ba9e7 100644
--- a/kernel/riscv.h
+++ b/kernel/riscv.h
@@ -343,6 +343,8 @@ sfence_vma()
#define PTE_W (1L << 2)
#define PTE_X (1L << 3)
#define PTE_U (1L << 4) // 1 -> user can access
+#define PTE_COW (1L << 8) // 1 -> is a COW page
+

// shift a physical address to the right place for a PTE.
#define PA2PTE(pa) ((((uint64)pa) >> 12) << 10)
diff --git a/kernel/trap.c b/kernel/trap.c
index a63249e..0fb7687 100644
--- a/kernel/trap.c
+++ b/kernel/trap.c
@@ -29,6 +29,42 @@ trapinithart(void)
w_stvec((uint64)kernelvec);
}

+
+int iscowpage(pagetable_t pgtbl, uint64 va) {
+ if (va >= MAXVA) return 0;
+ pte_t *pte = walk(pgtbl, va, 0);
+ if (pte == 0) return 0;
+ if ((*pte & PTE_V) == 0) return 0;
+ if ((*pte & PTE_U) == 0) return 0;
+ return *pte & PTE_COW;
+}
+
+int cowfault(pagetable_t pagetable, uint64 va) {
+ uint64 va0 = PGROUNDDOWN(va);
+ pte_t* pte;
+ if((pte = walk(pagetable, va0, 0)) == 0) return -1;
+
+ uint64 flags = PTE_FLAGS(*pte);
+ uint64 pa0 = PTE2PA(*pte);
+
+ flags &= (~PTE_COW); // clear COW bit
+ flags |= PTE_W; // set write bit
+
+ uint64 mem;
+ if ((mem = (uint64)kalloc()) == 0) return -1;
+ memmove((void *)mem, (void *)pa0, PGSIZE);
+
+ // remove old PTE
+ uvmunmap(pagetable, va0, 1, 1);
+
+ // install new PTE
+ if(mappages(pagetable, va0, PGSIZE, mem, flags) < 0){
+ kfree((void *)mem);
+ return -1;
+ }
+ return 0;
+}
+
//
// handle an interrupt, exception, or system call from user space.
// called from trampoline.S
@@ -67,7 +103,12 @@ usertrap(void)
syscall();
} else if((which_dev = devintr()) != 0){
// ok
- } else {
+ } else if (r_scause() == 15 && iscowpage(p->pagetable, r_stval())) {
+ if (cowfault(p->pagetable, r_stval()) < 0) {
+ p->killed = 1;
+ }
+ }
+ else {
printf("usertrap(): unexpected scause %p pid=%d\n", r_scause(), p->pid);
printf(" sepc=%p stval=%p\n", r_sepc(), r_stval());
p->killed = 1;
diff --git a/kernel/vm.c b/kernel/vm.c
index d5a12a0..df0ddde 100644
--- a/kernel/vm.c
+++ b/kernel/vm.c
@@ -303,22 +303,20 @@ uvmcopy(pagetable_t old, pagetable_t new, uint64 sz)
pte_t *pte;
uint64 pa, i;
uint flags;
- char *mem;

for(i = 0; i < sz; i += PGSIZE){
if((pte = walk(old, i, 0)) == 0)
panic("uvmcopy: pte should exist");
if((*pte & PTE_V) == 0)
panic("uvmcopy: page not present");
+ *pte &= ~PTE_W; // set write bit
+ *pte |= PTE_COW; // clear COW bit
pa = PTE2PA(*pte);
flags = PTE_FLAGS(*pte);
- if((mem = kalloc()) == 0)
- goto err;
- memmove(mem, (char*)pa, PGSIZE);
- if(mappages(new, i, PGSIZE, (uint64)mem, flags) != 0){
- kfree(mem);
+ if(mappages(new, i, PGSIZE, pa, flags) != 0){
goto err;
}
+ incrfcount((void*)pa); // increment reference count to pa
}
return 0;

@@ -350,6 +348,9 @@ copyout(pagetable_t pagetable, uint64 dstva, char *src, uint64 len)

while(len > 0){
va0 = PGROUNDDOWN(dstva);
+ if (iscowpage(pagetable, va0)) {
+ cowfault(pagetable, va0);
+ }
pa0 = walkaddr(pagetable, va0);
if(pa0 == 0)
return -1;
diff --git a/time.txt b/time.txt
new file mode 100644
index 0000000..209e3ef
--- /dev/null
+++ b/time.txt
@@ -0,0 +1 @@
+20