use petgraph::graph::{DefaultIx, EdgeIndex, NodeIndex};
use petgraph::Graph;
use std::collections::hash_map::Entry;
use std::collections::{HashMap, HashSet};
use std::fmt::{self, Debug};
use std::hash::Hash;
use crate::mir::analysis_context::AnalysisContext;
use crate::mir::call_site::{BaseCallSite, CallType, CSBaseCallSite};
use crate::mir::function::{FuncId, CSFuncId};
use crate::util::chunked_queue::{self, ChunkedQueue};
use crate::util::dot::Dot;
/// Identifier of a node in the underlying petgraph graph (default index type).
pub type CGNodeId = NodeIndex<DefaultIx>;
/// Identifier of an edge in the underlying petgraph graph (default index type).
pub type CGEdgeId = EdgeIndex<DefaultIx>;
/// Call graph keyed by `CSFuncId` functions and `CSBaseCallSite` call sites.
// NOTE(review): the `CS` prefix presumably stands for "context-sensitive" — confirm.
pub type CSCallGraph = CallGraph<CSFuncId, CSBaseCallSite>;
/// A function identifier that can be stored as a call-graph node.
///
/// Implementors must be cheap to copy, hashable and comparable, because they
/// are used both as node payloads and as `HashMap` keys.
pub trait CGFunction: Copy + Clone + PartialEq + Eq + Hash + Debug {
/// Writes the label used for this function when the graph is rendered to
/// DOT. `acx` is available for resolving the identifier to a printable form.
fn dot_fmt(&self, acx: &AnalysisContext, f: &mut fmt::Formatter) -> fmt::Result;
}
impl CGFunction for FuncId {
    /// Emits the string form of the function reference that `acx` associates
    /// with this id.
    fn dot_fmt(&self, acx: &AnalysisContext, f: &mut fmt::Formatter) -> fmt::Result {
        let label = acx.get_function_reference(*self).to_string();
        write!(f, "{}", label)
    }
}
impl CGFunction for CSFuncId {
    /// Emits the string form of the function reference that `acx` associates
    /// with `self.func_id`; no other part of `self` is used for the label.
    fn dot_fmt(&self, acx: &AnalysisContext, f: &mut fmt::Formatter) -> fmt::Result {
        let label = acx.get_function_reference(self.func_id).to_string();
        write!(f, "{}", label)
    }
}
/// A call-site identifier that can be stored as a call-graph edge.
///
/// Implementors must be cheap to copy, hashable and comparable, because they
/// are used both as edge payloads and as `HashMap` keys.
pub trait CGCallSite: Copy + Clone + PartialEq + Eq + Hash + Debug {
/// Writes the label used for this call site when the graph is rendered to DOT.
fn dot_fmt(&self, f: &mut fmt::Formatter) -> fmt::Result;
}
impl CGCallSite for BaseCallSite {
    /// Labels the call site with the `Debug` rendering of its `location`.
    fn dot_fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "{:?}", self.location)
    }
}
impl CGCallSite for CSBaseCallSite {
    /// Labels the call site with the `Debug` rendering of its `location`.
    fn dot_fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        write!(f, "{:?}", self.location)
    }
}
/// Node weight of the call graph: one node per function.
#[derive(Debug)]
pub struct CallGraphNode<F: CGFunction> {
    /// The function this node stands for.
    pub(crate) func: F,
}

impl<F: CGFunction> CallGraphNode<F> {
    /// Wraps `func` in a node weight.
    pub fn new(func: F) -> Self {
        Self { func }
    }
}
/// Edge weight of the call graph: one edge per (call site, callee) pair.
#[derive(Debug)]
pub struct CallGraphEdge<S: CGCallSite> {
    /// The call site this edge originates from.
    pub(crate) callsite: S,
}

impl<S: CGCallSite> CallGraphEdge<S> {
    /// Wraps `callsite` in an edge weight.
    pub fn new(callsite: S) -> Self {
        Self { callsite }
    }
}
/// A call graph whose nodes are functions of type `F` and whose edges are
/// call sites of type `S`.
pub struct CallGraph<F: CGFunction, S: CGCallSite> {
/// The underlying petgraph graph holding the node and edge weights.
pub graph: Graph<CallGraphNode<F>, CallGraphEdge<S>>,
/// Maps each function that has a node to that node's id.
pub func_nodes: HashMap<F, CGNodeId>,
/// Maps each call site to the ids of the edges that were added for it.
pub callsite_to_edges: HashMap<S, HashSet<CGEdgeId>>,
/// Call type recorded per base call site via `set_callsite_type`.
pub(crate) callsite_to_type: HashMap<BaseCallSite, CallType>,
/// Queue that receives a function whenever a node is first created for it,
/// and whenever `add_reach_func` is called.
pub(crate) reach_funcs: ChunkedQueue<F>,
}
impl<F: CGFunction, S: CGCallSite> CallGraph<F, S> {
    /// Creates an empty call graph with no nodes, edges, call-site types or
    /// queued functions.
    pub fn new() -> Self {
        CallGraph {
            graph: Graph::new(),
            func_nodes: HashMap::new(),
            callsite_to_edges: HashMap::new(),
            callsite_to_type: HashMap::new(),
            reach_funcs: ChunkedQueue::new(),
        }
    }

    /// Ensures that `func` has a node in the graph.
    ///
    /// If the node is newly created, `func` is also pushed onto the
    /// `reach_funcs` queue. Does nothing when the node already exists.
    pub fn add_node(&mut self, func: F) {
        self.get_or_insert_node(func);
    }

    /// Returns the node id of `func`, creating the node (and queueing `func`
    /// in `reach_funcs`) if it does not exist yet.
    fn get_or_insert_node(&mut self, func: F) -> CGNodeId {
        match self.func_nodes.entry(func) {
            Entry::Occupied(entry) => *entry.get(),
            Entry::Vacant(entry) => {
                let node_id = self.graph.add_node(CallGraphNode::new(func));
                self.reach_funcs.push(func);
                *entry.insert(node_id)
            }
        }
    }

    /// Records `call_type` for `callsite`, replacing any previous entry.
    pub fn set_callsite_type(&mut self, callsite: BaseCallSite, call_type: CallType) {
        self.callsite_to_type.insert(callsite, call_type);
    }

    /// Returns the call type recorded for `callsite`, if any.
    pub fn get_callsite_type(&self, callsite: &BaseCallSite) -> Option<&CallType> {
        self.callsite_to_type.get(callsite)
    }

    /// Returns the function stored at the target node of `edge_id`, or `None`
    /// if the edge (or its target node) is not present in the graph.
    pub fn get_callee_id_of_edge(&self, edge_id: EdgeIndex) -> Option<F> {
        let (_, callee_node) = self.edge_endpoints(edge_id)?;
        self.graph.node_weight(callee_node).map(|node| node.func)
    }

    /// Returns the `(source, target)` node ids of `edge_id`, or `None` if the
    /// edge is not present in the graph.
    pub fn edge_endpoints(&self, edge_id: EdgeIndex) -> Option<(CGNodeId, CGNodeId)> {
        self.graph.edge_endpoints(edge_id)
    }

    /// Lazily yields the callee of every edge recorded for `callsite`.
    ///
    /// Yields nothing when the call site has no recorded edges. The same
    /// callee may be yielded more than once; callers that need uniqueness
    /// should collect into a set.
    fn callees_iter<'a>(&'a self, callsite: &S) -> impl Iterator<Item = F> + 'a {
        self.callsite_to_edges
            .get(callsite)
            .into_iter()
            .flatten()
            .filter_map(move |edge_id| self.get_callee_id_of_edge(*edge_id))
    }

    /// Returns the set of functions that `callsite` has edges to.
    ///
    /// The set is empty when no edge has been added for `callsite`.
    pub fn get_callees(&self, callsite: &S) -> HashSet<F> {
        self.callees_iter(callsite).collect()
    }

    /// Returns `true` if an edge from `callsite` to `callee_id` exists.
    ///
    /// Only the call site and the callee are compared; the caller is not part
    /// of the check. No intermediate set is built.
    pub fn has_edge(&self, callsite: &S, callee_id: F) -> bool {
        self.callees_iter(callsite)
            .any(|callee| callee == callee_id)
    }

    /// Adds a call edge `caller_id -> callee_id` labelled with `callsite`.
    ///
    /// Nodes for both functions are created on demand, even when the edge
    /// itself turns out to exist already. Returns `true` if a new edge was
    /// inserted and `false` if `callsite` already had an edge to `callee_id`.
    pub fn add_edge(&mut self, callsite: S, caller_id: F, callee_id: F) -> bool {
        let caller_node = self.get_or_insert_node(caller_id);
        let callee_node = self.get_or_insert_node(callee_id);
        if self.has_edge(&callsite, callee_id) {
            return false;
        }
        let edge_id = self
            .graph
            .add_edge(caller_node, callee_node, CallGraphEdge::new(callsite));
        self.callsite_to_edges
            .entry(callsite)
            .or_default()
            .insert(edge_id);
        true
    }

    /// Pushes `func` onto the `reach_funcs` queue unconditionally; no
    /// duplicate check is performed here.
    pub fn add_reach_func(&mut self, func: F) {
        self.reach_funcs.push(func);
    }

    /// Returns a copying iterator over the `reach_funcs` queue.
    pub fn reach_funcs_iter(&self) -> chunked_queue::IterCopied<F> {
        self.reach_funcs.iter_copied()
    }

    /// Renders the graph in DOT format and writes it to `dot_path`.
    ///
    /// Node labels come from `CGFunction::dot_fmt` (using `acx`), edge labels
    /// from `CGCallSite::dot_fmt`.
    ///
    /// # Panics
    ///
    /// Panics if the file cannot be written.
    pub fn to_dot(&self, acx: &AnalysisContext, dot_path: &std::path::Path) {
        let node_fmt = |node: &CallGraphNode<F>, f: &mut fmt::Formatter| -> fmt::Result {
            node.func.dot_fmt(acx, f)
        };
        let edge_fmt = |edge: &CallGraphEdge<S>, f: &mut fmt::Formatter| -> fmt::Result {
            edge.callsite.dot_fmt(f)
        };
        let output = format!(
            "{:?}",
            Dot::with_graph_fmt(&self.graph, &[], &node_fmt, &edge_fmt)
        );
        if let Err(e) = std::fs::write(dot_path, output) {
            panic!("Failed to write dot file output: {:?}", e);
        }
    }
}