ANDROID: mm: introduce vma refcounting to protect vma during SPF

The current mechanism for stabilizing a vma during speculative page
fault handling makes a copy of the faulting vma under RCU protection.
This makes it hard to protect elements which do not belong to the vma
itself but are used by the page fault handler, such as vma->vm_file.
The problem is that a copy of the vma cannot be used to safely protect
the file attached to the original vma unless the file is also released
after an RCU grace period (which is how SPF was designed originally,
but that caused a performance regression and had to be changed).
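
A simplified picture of the race being described (the vma struct is
freed after an RCU grace period, but vma->vm_file is not):

	CPU 0 (SPF handler)               CPU 1 (unmap)
	rcu_read_lock()
	vma = __find_vma(mm, addr)
	                                  vm_area_free(vma)
	                                    fput(vma->vm_file)  /* file freed */
	                                    call_rcu(...)       /* vma freed later */
	get_file(vma->vm_file)            /* use-after-free on the file */
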
To avoid these complications, introduce vma refcounting to stabilize
and operate on the original vma during page fault handling. The page
fault handler finds the vma and increments its refcount under RCU
protection, the vma is freed only after an RCU grace period, and
vma->vm_file is released only once the refcount indicates there are
no users left. This mechanism guarantees that once get_vma() returns
a vma, both the vma itself and vma->vm_file are stable.
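
To illustrate the counting scheme (a simplified timeline; the count
starts at 0 so that get_vma() can use atomic_inc_unless_negative()):

	vma created                       file_ref_count == 0
	SPF reader A: get_vma()           0 -> 1
	SPF reader B: get_vma()           1 -> 2
	unmap: vm_area_free()             2 -> 1   (>= 0, freeing deferred)
	reader A: put_vma()               1 -> 0   (>= 0, reader B still active)
	reader B: put_vma()               0 -> -1  (< 0, vm_area_free_no_check())

Once the count goes negative, atomic_inc_unless_negative() in get_vma()
fails, so the final free cannot race with a new speculative reader.
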
Additional benefits of this patch: the vma no longer needs to be
copied, and no extra logic is required to stabilize vma->vm_file.
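
For reference, the freeing path after this patch, condensed from the
kernel/fork.c hunk below (SPF-enabled configuration):

	void vm_area_free(struct vm_area_struct *vma)
	{
		/* The unmap-side decrement acts as the final put: whoever
		 * drives the count negative performs the actual free.
		 */
		if (atomic_dec_return(&vma->file_ref_count) >= 0)
			return;	/* speculative readers still hold references */
		vm_area_free_no_check(vma);
	}

	void vm_area_free_no_check(struct vm_area_struct *vma)
	{
		free_anon_vma_name(vma);
		if (vma->vm_file)
			fput(vma->vm_file);	/* released once no users remain */
		free_vm_area_struct(vma);	/* freed after an RCU grace period */
	}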

Bug: 257443051
Change-Id: I59d373926d687fcbd56847a8c3500c43bf1844c8
Signed-off-by: Suren Baghdasaryan 
diff --git a/arch/arm64/mm/fault.c b/arch/arm64/mm/fault.c
index 92cc26d..0218230 100644
--- a/arch/arm64/mm/fault.c
+++ b/arch/arm64/mm/fault.c
@@ -542,9 +542,7 @@ static int __kprobes do_page_fault(unsigned long far, unsigned int esr,
 	unsigned int mm_flags = FAULT_FLAG_DEFAULT;
 	unsigned long addr = untagged_addr(far);
 #ifdef CONFIG_SPECULATIVE_PAGE_FAULT
-	struct file *orig_file = NULL;
 	struct vm_area_struct *vma;
-	struct vm_area_struct pvma;
 	unsigned long seq;
 #endif
 
@@ -618,38 +616,29 @@ static int __kprobes do_page_fault(unsigned long far, unsigned int esr,
 		count_vm_spf_event(SPF_ABORT_ODD);
 		goto spf_abort;
 	}
-	rcu_read_lock();
-	vma = __find_vma(mm, addr);
-	if (!vma || vma->vm_start > addr) {
-		rcu_read_unlock();
+	vma = get_vma(mm, addr);
+	if (!vma) {
 		count_vm_spf_event(SPF_ABORT_UNMAPPED);
 		goto spf_abort;
 	}
 	if (!vma_can_speculate(vma, mm_flags)) {
-		rcu_read_unlock();
+		put_vma(vma);
 		count_vm_spf_event(SPF_ABORT_NO_SPECULATE);
 		goto spf_abort;
 	}
-	if (vma->vm_file)
-		orig_file = get_file(vma->vm_file);
-	pvma = *vma;
-	rcu_read_unlock();
+
 	if (!mmap_seq_read_check(mm, seq, SPF_ABORT_VMA_COPY)) {
-		if (orig_file)
-			fput(orig_file);
+		put_vma(vma);
 		goto spf_abort;
 	}
-	vma = &pvma;
 	if (!(vma->vm_flags & vm_flags)) {
-		if (orig_file)
-			fput(orig_file);
+		put_vma(vma);
 		count_vm_spf_event(SPF_ABORT_ACCESS_ERROR);
 		goto spf_abort;
 	}
 	fault = do_handle_mm_fault(vma, addr & PAGE_MASK,
 			mm_flags | FAULT_FLAG_SPECULATIVE, seq, regs);
-	if (orig_file)
-		fput(orig_file);
+	put_vma(vma);
 
 	/* Quick path to respond to signals */
 	if (fault_signal_pending(fault, regs)) {
diff --git a/arch/powerpc/mm/fault.c b/arch/powerpc/mm/fault.c
index 0799a058..888c12f 100644
--- a/arch/powerpc/mm/fault.c
+++ b/arch/powerpc/mm/fault.c
@@ -395,8 +395,6 @@ static int ___do_page_fault(struct pt_regs *regs, unsigned long address,
 	vm_fault_t fault, major = 0;
 	bool kprobe_fault = kprobe_page_fault(regs, 11);
 #ifdef CONFIG_SPECULATIVE_PAGE_FAULT
-	struct file *orig_file = NULL;
-	struct vm_area_struct pvma;
 	unsigned long seq;
 #endif
 
@@ -469,47 +467,37 @@ static int ___do_page_fault(struct pt_regs *regs, unsigned long address,
 		count_vm_spf_event(SPF_ABORT_ODD);
 		goto spf_abort;
 	}
-	rcu_read_lock();
-	vma = __find_vma(mm, address);
-	if (!vma || vma->vm_start > address) {
-		rcu_read_unlock();
+	vma = get_vma(mm, address);
+	if (!vma) {
 		count_vm_spf_event(SPF_ABORT_UNMAPPED);
 		goto spf_abort;
 	}
 	if (!vma_can_speculate(vma, flags)) {
-		rcu_read_unlock();
+		put_vma(vma);
 		count_vm_spf_event(SPF_ABORT_NO_SPECULATE);
 		goto spf_abort;
 	}
-	if (vma->vm_file)
-		orig_file = get_file(vma->vm_file);
-	pvma = *vma;
-	rcu_read_unlock();
+
 	if (!mmap_seq_read_check(mm, seq, SPF_ABORT_VMA_COPY)) {
-		if (orig_file)
-			fput(orig_file);
+		put_vma(vma);
 		goto spf_abort;
 	}
-	vma = &pvma;
 #ifdef CONFIG_PPC_MEM_KEYS
 	if (unlikely(access_pkey_error(is_write, is_exec,
 				       (error_code & DSISR_KEYFAULT), vma))) {
-		if (orig_file)
-			fput(orig_file);
+		put_vma(vma);
 		count_vm_spf_event(SPF_ABORT_ACCESS_ERROR);
 		goto spf_abort;
 	}
 #endif /* CONFIG_PPC_MEM_KEYS */
 	if (unlikely(access_error(is_write, is_exec, vma))) {
-		if (orig_file)
-			fput(orig_file);
+		put_vma(vma);
 		count_vm_spf_event(SPF_ABORT_ACCESS_ERROR);
 		goto spf_abort;
 	}
 	fault = do_handle_mm_fault(vma, address,
-				   flags | FAULT_FLAG_SPECULATIVE, seq, regs);
-	if (orig_file)
-		fput(orig_file);
+			flags | FAULT_FLAG_SPECULATIVE, seq, regs);
+	put_vma(vma);
 	major |= fault & VM_FAULT_MAJOR;
 
 	if (fault_signal_pending(fault, regs))
diff --git a/arch/x86/mm/fault.c b/arch/x86/mm/fault.c
index 7b05f6d..83e07cb 100644
--- a/arch/x86/mm/fault.c
+++ b/arch/x86/mm/fault.c
@@ -1227,8 +1227,6 @@ void do_user_addr_fault(struct pt_regs *regs,
 	vm_fault_t fault;
 	unsigned int flags = FAULT_FLAG_DEFAULT;
 #ifdef CONFIG_SPECULATIVE_PAGE_FAULT
-	struct file *orig_file = NULL;
-	struct vm_area_struct pvma;
 	unsigned long seq;
 #endif
 
@@ -1342,38 +1340,30 @@ void do_user_addr_fault(struct pt_regs *regs,
 		count_vm_spf_event(SPF_ABORT_ODD);
 		goto spf_abort;
 	}
-	rcu_read_lock();
-	vma = __find_vma(mm, address);
-	if (!vma || vma->vm_start > address) {
-		rcu_read_unlock();
+	vma = get_vma(mm, address);
+	if (!vma) {
 		count_vm_spf_event(SPF_ABORT_UNMAPPED);
 		goto spf_abort;
 	}
+
 	if (!vma_can_speculate(vma, flags)) {
-		rcu_read_unlock();
+		put_vma(vma);
 		count_vm_spf_event(SPF_ABORT_NO_SPECULATE);
 		goto spf_abort;
 	}
-	if (vma->vm_file)
-		orig_file = get_file(vma->vm_file);
-	pvma = *vma;
-	rcu_read_unlock();
+
 	if (!mmap_seq_read_check(mm, seq, SPF_ABORT_VMA_COPY)) {
-		if (orig_file)
-			fput(orig_file);
+		put_vma(vma);
 		goto spf_abort;
 	}
-	vma = &pvma;
 	if (unlikely(access_error(error_code, vma))) {
-		if (orig_file)
-			fput(orig_file);
+		put_vma(vma);
 		count_vm_spf_event(SPF_ABORT_ACCESS_ERROR);
 		goto spf_abort;
 	}
 	fault = do_handle_mm_fault(vma, address,
-				   flags | FAULT_FLAG_SPECULATIVE, seq, regs);
-	if (orig_file)
-		fput(orig_file);
+			flags | FAULT_FLAG_SPECULATIVE, seq, regs);
+	put_vma(vma);
 
 	if (!(fault & VM_FAULT_RETRY))
 		goto done;
diff --git a/include/linux/mm.h b/include/linux/mm.h
index 7710684..21c8954d 100644
--- a/include/linux/mm.h
+++ b/include/linux/mm.h
@@ -253,6 +253,7 @@ void setup_initial_init_mm(void *start_code, void *end_code,
 
 struct vm_area_struct *vm_area_alloc(struct mm_struct *);
 struct vm_area_struct *vm_area_dup(struct vm_area_struct *);
+void vm_area_free_no_check(struct vm_area_struct *);
 void vm_area_free(struct vm_area_struct *);
 
 #ifndef CONFIG_MMU
@@ -685,6 +686,10 @@ static inline void vma_init(struct vm_area_struct *vma, struct mm_struct *mm)
 	memset(vma, 0, sizeof(*vma));
 	vma->vm_mm = mm;
 	vma->vm_ops = &dummy_vm_ops;
+#ifdef CONFIG_SPECULATIVE_PAGE_FAULT
+	/* Start from 0 to use atomic_inc_unless_negative() in get_vma() */
+	atomic_set(&vma->file_ref_count, 0);
+#endif
 	INIT_LIST_HEAD(&vma->anon_vma_chain);
 }
 
@@ -3383,6 +3388,9 @@ static inline bool pte_spinlock(struct vm_fault *vmf)
 	return __pte_map_lock(vmf);
 }
 
+struct vm_area_struct *get_vma(struct mm_struct *mm, unsigned long addr);
+void put_vma(struct vm_area_struct *vma);
+
 #else	/* !CONFIG_SPECULATIVE_PAGE_FAULT */
 
 #define pte_map_lock(___vmf)						\
diff --git a/include/linux/mm_types.h b/include/linux/mm_types.h
index 3142ce9..7127dc5 100644
--- a/include/linux/mm_types.h
+++ b/include/linux/mm_types.h
@@ -419,6 +419,11 @@ struct vm_area_struct {
 #endif
 	struct vm_userfaultfd_ctx vm_userfaultfd_ctx;
 #ifdef CONFIG_SPECULATIVE_PAGE_FAULT
+	/*
+	 * The name no longer reflects the usage but is kept intact to
+	 * preserve the ABI.
+	 * This field is used to refcount the VMA in get_vma()/put_vma().
+	 */
 	atomic_t file_ref_count;
 #endif
 
diff --git a/kernel/fork.c b/kernel/fork.c
index ad81282..0a3c247 100644
--- a/kernel/fork.c
+++ b/kernel/fork.c
@@ -381,32 +381,41 @@ struct vm_area_struct *vm_area_dup(struct vm_area_struct *orig)
 	return new;
 }
 
-static inline void ____vm_area_free(struct vm_area_struct *vma)
-{
-	kmem_cache_free(vm_area_cachep, vma);
-}
-
 #ifdef CONFIG_SPECULATIVE_PAGE_FAULT
-static void __vm_area_free(struct rcu_head *head)
+static void __free_vm_area_struct(struct rcu_head *head)
 {
 	struct vm_area_struct *vma = container_of(head, struct vm_area_struct,
 						  vm_rcu);
-	____vm_area_free(vma);
+	kmem_cache_free(vm_area_cachep, vma);
+}
+
+static inline void free_vm_area_struct(struct vm_area_struct *vma)
+{
+	call_rcu(&vma->vm_rcu, __free_vm_area_struct);
+}
+#else
+static inline void free_vm_area_struct(struct vm_area_struct *vma)
+{
+	kmem_cache_free(vm_area_cachep, vma);
 }
 #endif
 
-void vm_area_free(struct vm_area_struct *vma)
+void vm_area_free_no_check(struct vm_area_struct *vma)
 {
 	free_anon_vma_name(vma);
 	if (vma->vm_file)
 		fput(vma->vm_file);
+	free_vm_area_struct(vma);
+}
+
+void vm_area_free(struct vm_area_struct *vma)
+{
 #ifdef CONFIG_SPECULATIVE_PAGE_FAULT
-	if (atomic_read(&vma->vm_mm->mm_users) > 1) {
-		call_rcu(&vma->vm_rcu, __vm_area_free);
+	/* Free only once the refcount drops below zero */
+	if (atomic_dec_return(&vma->file_ref_count) >= 0)
 		return;
-	}
 #endif
-	____vm_area_free(vma);
+	vm_area_free_no_check(vma);
 }
 
 static void account_kernel_stack(struct task_struct *tsk, int account)
diff --git a/mm/memory.c b/mm/memory.c
index 17a03bf..6bc34d7 100644
--- a/mm/memory.c
+++ b/mm/memory.c
@@ -209,6 +209,35 @@ static void check_sync_rss_stat(struct task_struct *task)
 
 #endif /* SPLIT_RSS_COUNTING */
 
+#ifdef CONFIG_SPECULATIVE_PAGE_FAULT
+
+struct vm_area_struct *get_vma(struct mm_struct *mm, unsigned long addr)
+{
+	struct vm_area_struct *vma;
+
+	rcu_read_lock();
+	vma = __find_vma(mm, addr);
+	if (vma) {
+		if (vma->vm_start > addr ||
+		    !atomic_inc_unless_negative(&vma->file_ref_count))
+			vma = NULL;
+	}
+	rcu_read_unlock();
+
+	return vma;
+}
+
+void put_vma(struct vm_area_struct *vma)
+{
+	int new_ref_count;
+
+	new_ref_count = atomic_dec_return(&vma->file_ref_count);
+	if (new_ref_count < 0)
+		vm_area_free_no_check(vma);
+}
+
+#endif	/* CONFIG_SPECULATIVE_PAGE_FAULT */
+
 /*
  * Note: this doesn't free the actual pages themselves. That
  * has been handled earlier when unmapping all the memory regions.