use std::collections::HashSet;
use std::fmt::{Debug, Formatter, Result};
use std::rc::Rc;
use std::time::Instant;
use itertools::Itertools;
use log::*;
use rustc_middle::ty::TyCtxt;
use super::context_strategy::KObjectSensitive;
use super::propagator::propagator::Propagator;
use super::PointerAnalysis;
use crate::graph::func_pag::FuncPAG;
use crate::graph::pag::*;
use crate::graph::call_graph::CSCallGraph;
use crate::mir::call_site::{AssocCallGroup, CSCallSite, CallSite, CallType};
use crate::mir::context::{Context, ContextId};
use crate::mir::function::{FuncId, CSFuncId};
use crate::mir::analysis_context::AnalysisContext;
use crate::mir::path::{Path, CSPath, PathEnum};
use crate::pta::*;
use crate::pta::context_strategy::ContextStrategy;
use crate::util::pta_statistics::ContextSensitiveStat;
use crate::util::{self, chunked_queue, results_dumper};
/// [`ContextSensitivePTA`] instantiated with the `KCallSiteSensitive` context strategy.
pub type CallSiteSensitivePTA<'pta, 'tcx, 'compilation> = ContextSensitivePTA<'pta, 'tcx, 'compilation, KCallSiteSensitive>;
/// [`ContextSensitivePTA`] instantiated with the `KObjectSensitive` context strategy.
pub type ObjectSensitivePTA<'pta, 'tcx, 'compilation> = ContextSensitivePTA<'pta, 'tcx, 'compilation, KObjectSensitive>;
/// Context-sensitive pointer analysis, parameterized over the context strategy `S`.
///
/// The analysis alternates between solving the propagation worklist and
/// processing the calls discovered while solving (see `propagate`).
pub struct ContextSensitivePTA<'pta, 'tcx, 'compilation, S: ContextStrategy> {
/// Analysis context, borrowed mutably for as long as this analysis lives.
pub(crate) acx: &'pta mut AnalysisContext<'tcx, 'compilation>,
/// Points-to data that is handed to the propagator and dumped in `finalize`.
pub(crate) pt_data: DiffPTDataTy,
/// PAG whose nodes are context-sensitive paths (presumably the pointer
/// assignment graph -- confirm against `crate::graph::pag`).
pub(crate) pag: PAG<Rc<CSPath>>,
/// Call graph over context-sensitive functions.
pub call_graph: CSCallGraph,
/// Context-sensitive functions whose function-PAG edges were already added.
pub(crate) processed_funcs: HashSet<CSFuncId>,
// Iterator obtained from `call_graph.reach_funcs_iter()`; drained by
// `process_reach_funcs`.
rf_iter: chunked_queue::IterCopied<CSFuncId>,
// Iterator obtained from `pag.addr_edge_iter()`; handed to the propagator.
addr_edge_iter: chunked_queue::IterCopied<EdgeId>,
/// Queue of inter-procedural edge ids, filled when a call edge is added.
pub(crate) inter_proc_edges_queue: chunked_queue::ChunkedQueue<EdgeId>,
// Call sites registered against PAG nodes (dynamic dispatch, Fn* trait
// objects, function pointers, static-dispatch instance calls).
assoc_calls: AssocCallGroup<NodeId, CSFuncId, Rc<CSPath>>,
// Strategy that creates and looks up calling contexts.
ctx_strategy: S,
}
/// Opaque `Debug` implementation: prints the type name only, because the
/// analysis state is far too large to be useful in a debug dump.
impl<'pta, 'tcx, 'compilation, S: ContextStrategy> Debug for ContextSensitivePTA<'pta, 'tcx, 'compilation, S> {
    fn fmt(&self, f: &mut Formatter<'_>) -> Result {
        // Formatting the name through `<str as Debug>::fmt` would render it as
        // a quoted string literal. `finish_non_exhaustive` prints
        // `ContextSensitivePTA { .. }`, making clear that fields are omitted.
        f.debug_struct("ContextSensitivePTA").finish_non_exhaustive()
    }
}
impl<'pta, 'tcx, 'compilation, S: ContextStrategy> ContextSensitivePTA<'pta, 'tcx, 'compilation, S> {
pub fn new(acx: &'pta mut AnalysisContext<'tcx, 'compilation>, ctx_strategy: S) -> Self {
let call_graph = CSCallGraph::new();
let rf_iter = call_graph.reach_funcs_iter();
let pag = PAG::new();
let addr_edge_iter = pag.addr_edge_iter();
ContextSensitivePTA {
acx,
pt_data: DiffPTDataTy::new(),
pag,
call_graph,
processed_funcs: HashSet::new(),
rf_iter,
addr_edge_iter,
inter_proc_edges_queue: chunked_queue::ChunkedQueue::new(),
assoc_calls: AssocCallGroup::new(),
ctx_strategy,
}
}
/// Returns the `TyCtxt` stored in the analysis context.
#[inline]
fn tcx(&self) -> TyCtxt<'tcx> {
self.acx.tcx
}
/// Returns the id of `context`, as reported by the context strategy.
///
/// Takes `&mut self` because the strategy call requires mutable access
/// (presumably to register contexts it has not seen yet -- confirm in the
/// strategy implementation).
#[inline]
pub fn get_context_id(&mut self, context: &Rc<Context<S::E>>) -> ContextId {
self.ctx_strategy.get_context_id(context)
}
/// Looks up the context registered under `context_id` in the context strategy.
#[inline]
pub fn get_context_by_id(&self, context_id: ContextId) -> Rc<Context<S::E>> {
self.ctx_strategy.get_context_by_id(context_id)
}
/// Returns the id of the empty context, as reported by the context strategy.
#[inline]
pub fn get_empty_context_id(&mut self) -> ContextId {
self.ctx_strategy.get_empty_context_id()
}
/// Seeds the call graph with the entry point under the empty context and
/// processes every function that becomes reachable as a result.
pub fn initialize(&mut self) {
    let entry_def = self.acx.entry_point;
    let root_cid = self.get_empty_context_id();

    // The entry point is looked up with an empty generic argument list.
    let no_generic_args = self.tcx().mk_args(&[]);
    let entry_func = self.acx.get_func_id(entry_def, no_generic_args);

    let cs_entry = CSFuncId::new(root_cid, entry_func);
    self.call_graph.add_node(cs_entry);
    self.process_reach_funcs();
}
/// Runs the analysis to a fixed point.
///
/// Each round solves the propagator worklist, then feeds the calls the
/// propagator discovered back into the call graph and PAG. The loop stops
/// at the first round that discovers no new call.
pub fn propagate(&mut self) {
    let mut inter_proc_cursor = self.inter_proc_edges_queue.iter_copied();
    loop {
        let mut discovered_calls: Vec<(Rc<CSCallSite>, FuncId)> = Vec::new();
        let mut discovered_instance_calls: Vec<(Rc<CSCallSite>, Rc<CSPath>, FuncId)> = Vec::new();

        // The propagator only lives for this statement, so its mutable
        // borrows are released before the discovered calls are processed.
        Propagator::new(
            self.acx,
            &mut self.pt_data,
            &mut self.pag,
            &mut discovered_calls,
            &mut discovered_instance_calls,
            &mut self.addr_edge_iter,
            &mut inter_proc_cursor,
            &mut self.assoc_calls,
        )
        .solve_worklist();

        if discovered_calls.is_empty() && discovered_instance_calls.is_empty() {
            break;
        }
        self.process_new_calls(&discovered_calls);
        self.process_new_call_instances(&discovered_instance_calls);
    }
}
/// Drains the reachable-function iterator. For every function that has not
/// been processed yet, builds its function PAG and, when `build_func_pag`
/// returns `true`, adds its edges and processes its call sites.
fn process_reach_funcs(&mut self) {
    // A `for` loop would keep `self.rf_iter` borrowed across the body, which
    // needs `&mut self`; hence the explicit `next()` calls.
    while let Some(cs_func) = self.rf_iter.next() {
        if self.processed_funcs.contains(&cs_func) {
            continue;
        }

        let func_ref = self.acx.get_function_reference(cs_func.func_id);
        info!(
            "Processing function {:?} {}, context: {:?}",
            cs_func.func_id,
            func_ref.to_string(),
            self.get_context_by_id(cs_func.cid),
        );

        let built = self.pag.build_func_pag(self.acx, cs_func.func_id);
        if !built {
            continue;
        }
        self.add_fpag_edges(cs_func);
        self.process_calls_in_fpag(cs_func);
    }
}
pub fn add_fpag_edges(&mut self, func: CSFuncId) {
if self.processed_funcs.contains(&func) {
return;
}
let fpag = unsafe { &*(self.pag.func_pags.get(&func.func_id).unwrap() as *const FuncPAG) };
let edges_iter = fpag.internal_edges_iter();
for (src, dst, kind) in edges_iter {
let cs_src = self.mk_cs_path(src, func.cid);
let cs_dst = self.mk_cs_path(dst, func.cid);
self.pag.add_edge(&cs_src, &cs_dst, kind.clone());
}
if let Some(promoted_funcs) = self.pag.promoted_funcs_map.get(&func.func_id) {
let promoted_funcs = unsafe { &*(promoted_funcs as *const HashSet<FuncId>) };
for promoted_func in promoted_funcs {
let cs_promtoted_func = CSFuncId::new(self.get_empty_context_id(), *promoted_func);
self.add_fpag_edges(cs_promtoted_func);
}
}
if let Some(static_funcs) = self.pag.involved_static_funcs_map.get(&func.func_id) {
let static_funcs = unsafe { &*(static_funcs as *const HashSet<FuncId>) };
for static_func in static_funcs {
let cs_static_func = CSFuncId::new(self.get_empty_context_id(), *static_func);
self.add_fpag_edges(cs_static_func);
}
}
self.processed_funcs.insert(func);
}
fn process_calls_in_fpag(&mut self, func: CSFuncId) {
let fpag = unsafe { &*(self.pag.get_func_pag(&func.func_id).unwrap() as *const FuncPAG) };
for (callsite, callee) in &fpag.static_dispatch_callsites {
let cs_callsite = self.mk_cs_callsite(callsite, func.cid);
self.process_new_call(&cs_callsite, callee);
self.call_graph.set_callsite_type(callsite.into(), CallType::StaticDispatch);
}
for (callsite, callee) in &fpag.special_callsites {
let cs_callsite = self.mk_cs_callsite(callsite, func.cid);
let empty_cid = self.special_callsite_context(&cs_callsite, callee);
let cs_callee = self.mk_cs_func(*callee, empty_cid);
self.call_graph.add_edge(cs_callsite.into(), func, cs_callee);
self.call_graph.set_callsite_type(callsite.into(), CallType::StaticDispatch);
}
for (dyn_fn_obj, callsite) in &fpag.dynamic_fntrait_callsites {
let cs_dyn_fn_obj = self.mk_cs_path(dyn_fn_obj, func.cid);
let cs_callsite = self.mk_cs_callsite(callsite, func.cid);
let dyn_node_id = self.dyn_node_id(&cs_dyn_fn_obj);
self.assoc_calls.add_dynamic_fntrait_call(dyn_node_id, cs_callsite);
self.call_graph.set_callsite_type(callsite.into(), CallType::DynamicFnTrait);
}
for (dyn_var, callsite) in &fpag.dynamic_dispatch_callsites {
let cs_dyn_var = self.mk_cs_path(dyn_var, func.cid);
let cs_callsite = self.mk_cs_callsite(callsite, func.cid);
let dyn_node_id = self.dyn_node_id(&cs_dyn_var);
self.assoc_calls.add_dynamic_dispatch_call(dyn_node_id, cs_callsite);
self.call_graph.set_callsite_type(callsite.into(), CallType::DynamicDispatch);
}
for (fn_ptr, callsite) in &fpag.fnptr_callsites {
let cs_fn_ptr = self.mk_cs_path(fn_ptr, func.cid);
let cs_callsite = self.mk_cs_callsite(callsite, func.cid);
self.assoc_calls.add_fnptr_call(self.pag.get_or_insert_node(&cs_fn_ptr), cs_callsite);
self.call_graph.set_callsite_type(callsite.into(), CallType::FnPtr);
}
}
/// Returns the PAG node id for `dyn_obj`, inserting the node if it is missing.
fn dyn_node_id(&mut self, dyn_obj: &Rc<CSPath>) -> NodeId {
self.pag.get_or_insert_node(dyn_obj)
}
/// Resolves a call from `callsite` to `callee`, choosing the callee context
/// according to the shape of the callee's `self` parameter.
///
/// * No `self` parameter: a static call context is created.
/// * `self` parameter for which `util::has_self_ref_parameter` is false: an
///   instance call context is requested with the first argument as instance.
/// * `self` parameter for which `util::has_self_ref_parameter` is true: an
///   instance call context is requested without an instance, and the call is
///   also registered in `assoc_calls` against the node of the first argument.
///
/// # Panics
///
/// Panics with "invalid arguments" if the callee has a `self` parameter but
/// the call site has no arguments.
fn process_new_call(&mut self, callsite: &Rc<CSCallSite>, callee: &FuncId) {
    let callee_id = *callee;
    let def_id = self.acx.get_function_reference(callee_id).def_id;

    if !util::has_self_parameter(self.tcx(), def_id) {
        let cid = self.ctx_strategy.new_static_call_context(callsite);
        self.add_call_edge(callsite, &CSFuncId::new(cid, callee_id));
        return;
    }

    if !util::has_self_ref_parameter(self.tcx(), def_id) {
        let receiver = callsite.args.get(0).expect("invalid arguments");
        if let Some(cid) = self.ctx_strategy.new_instance_call_context(callsite, Some(receiver)) {
            self.add_call_edge(callsite, &CSFuncId::new(cid, callee_id));
        }
        return;
    }

    if let Some(cid) = self.ctx_strategy.new_instance_call_context(callsite, None) {
        self.add_call_edge(callsite, &CSFuncId::new(cid, callee_id));
    }
    let receiver: &Rc<CSPath> = callsite.args.get(0).expect("invalid arguments");
    let receiver_node = self.pag.get_or_insert_node(receiver);
    self.assoc_calls
        .add_static_dispatch_instance_call(receiver_node, callsite.clone(), callee_id);
}
/// Returns the callee context for a special call site. The callee is ignored:
/// the context is always the static call context the strategy creates for
/// `callsite`.
fn special_callsite_context(&mut self, callsite: &Rc<CSCallSite>, _callee: &FuncId) -> ContextId {
self.ctx_strategy.new_static_call_context(callsite)
}
/// Resolves every `(callsite, callee)` pair in `new_calls`, then processes
/// the functions that became reachable as a result.
///
/// Takes a slice rather than `&Vec`, so any contiguous sequence can be
/// passed; existing callers that pass `&Vec<_>` still compile through deref
/// coercion.
fn process_new_calls(&mut self, new_calls: &[(Rc<CSCallSite>, FuncId)]) {
    for (callsite, callee_id) in new_calls {
        self.process_new_call(callsite, callee_id);
    }
    self.process_reach_funcs();
}
/// For every `(callsite, instance, callee)` triple, asks the context strategy
/// for an instance call context and, if it returns one, adds the
/// corresponding call edge. Afterwards processes the functions that became
/// reachable as a result.
///
/// Takes a slice rather than `&Vec`, so any contiguous sequence can be
/// passed; existing callers that pass `&Vec<_>` still compile through deref
/// coercion.
fn process_new_call_instances(&mut self, new_call_instances: &[(Rc<CSCallSite>, Rc<CSPath>, FuncId)]) {
    for (callsite, instance, callee_id) in new_call_instances {
        if let Some(callee_cid) = self.ctx_strategy.new_instance_call_context(callsite, Some(instance)) {
            let cs_callee = CSFuncId::new(callee_cid, *callee_id);
            self.add_call_edge(callsite, &cs_callee);
        }
    }
    self.process_reach_funcs();
}
/// Adds a call-graph edge from `callsite` to `callee`. If `add_edge` returns
/// `true`, also adds the matching inter-procedural PAG edges and queues them
/// in `inter_proc_edges_queue`; if it returns `false` nothing else is done
/// (presumably the edge already existed -- confirm in `CSCallGraph::add_edge`).
fn add_call_edge(&mut self, callsite: &Rc<CSCallSite>, callee: &CSFuncId) {
    let cs_callee = *callee;
    let cs_caller = callsite.func;

    let added = self.call_graph.add_edge(callsite.into(), cs_caller, cs_callee);
    if added {
        let fresh_edges = self.pag.add_inter_procedural_edges(self.acx, callsite, cs_callee);
        for edge_id in fresh_edges {
            self.inter_proc_edges_queue.push(edge_id);
        }
    }
}
/// Wraps `path` into a context-sensitive path.
///
/// Parameters, locals, return values, auxiliaries, qualified/offset paths and
/// heap objects are qualified with `cid`. Constants, statics, promoted
/// constants, functions, the promoted array variants and types are always
/// qualified with the empty context.
fn mk_cs_path(&mut self, path: &Rc<Path>, cid: ContextId) -> Rc<CSPath> {
    // The match is kept exhaustive on purpose: a new `PathEnum` variant must
    // be classified explicitly.
    let effective_cid = match path.value() {
        PathEnum::Parameter { .. }
        | PathEnum::LocalVariable { .. }
        | PathEnum::ReturnValue { .. }
        | PathEnum::Auxiliary { .. }
        | PathEnum::QualifiedPath { .. }
        | PathEnum::OffsetPath { .. }
        | PathEnum::HeapObj { .. } => cid,
        PathEnum::Constant
        | PathEnum::StaticVariable { .. }
        | PathEnum::PromotedConstant { .. }
        | PathEnum::Function(..)
        | PathEnum::PromotedStrRefArray
        | PathEnum::PromotedArgumentV1Array
        | PathEnum::Type(..) => self.get_empty_context_id(),
    };
    CSPath::new_cs_path(effective_cid, path.clone())
}
/// Pairs `func_id` with the context `cid`. Does not touch `self`.
fn mk_cs_func(&mut self, func_id: FuncId, cid: ContextId) -> CSFuncId {
CSFuncId { cid, func_id }
}
/// Lifts `callsite` into the context `cid`: the enclosing function, every
/// argument and the destination are qualified via `mk_cs_path` (so
/// context-insensitive paths still end up in the empty context).
fn mk_cs_callsite(&mut self, callsite: &Rc<CallSite>, cid: ContextId) -> Rc<CSCallSite> {
    let cs_caller = CSFuncId { cid, func_id: callsite.func };
    let location = callsite.location;

    // Arguments are converted before the destination, in call-site order.
    let cs_args = callsite
        .args
        .iter()
        .map(|arg| self.mk_cs_path(arg, cid))
        .collect::<Vec<_>>();
    let cs_destination = self.mk_cs_path(&callsite.destination, cid);

    Rc::new(CSCallSite::new(cs_caller, location, cs_args, cs_destination))
}
/// Returns a shared reference to the points-to data.
#[inline]
pub fn get_pt_data(&self) -> &DiffPTDataTy {
&self.pt_data
}
/// Dumps the analysis results (call graph, points-to data and PAG) and then
/// the context-sensitive statistics.
pub fn finalize(&self) {
    results_dumper::dump_results(self.acx, &self.call_graph, &self.pt_data, &self.pag);
    ContextSensitiveStat::new(self).dump_stats();
}
}
impl<'pta, 'tcx, 'compilation, S: ContextStrategy> PointerAnalysis<'tcx, 'compilation>
    for ContextSensitivePTA<'pta, 'tcx, 'compilation, S>
{
    /// Runs the whole analysis: initialization, propagation to a fixed point
    /// and finalization. The logged time covers initialization and
    /// propagation only, not finalization.
    fn analyze(&mut self) {
        let timer = Instant::now();
        self.initialize();
        self.propagate();
        let analysis_time = timer.elapsed();

        info!("Context-sensitive PTA completed.");
        // `format_duration` returns a `Display` value, so it is formatted
        // directly instead of going through an intermediate `String`.
        info!("Analysis time: {}", humantime::format_duration(analysis_time));

        self.finalize();
    }
}