Initialize FastCausalSHAP with data, model, and target variable.
Parameters:
| Name | Type | Description | Default |
| --- | --- | --- | --- |
| data | DataFrame | The dataset containing features and target variable. Must not be empty. | required |
| model | Any | A fitted sklearn model with a predict() method and a feature_names_in_ attribute. Can be a classifier or regressor. | required |
| target_variable | str | The name of the target variable column in the data. Must exist in data.columns. | required |
Raises:
| Type | Description |
| --- | --- |
| TypeError | If data is not a pandas DataFrame. |
| ValueError | If data is empty or target_variable is not in the data columns. |
| AttributeError | If model doesn't have the required methods/attributes. |
Examples:
>>> from sklearn.ensemble import RandomForestRegressor
>>> import pandas as pd
>>>
>>> data = pd.DataFrame({'X1': [1, 2, 3], 'X2': [4, 5, 6], 'Y': [7, 8, 9]})
>>> model = RandomForestRegressor()
>>> model.fit(data[['X1', 'X2']], data['Y'])
>>>
>>> shap = FastCausalSHAP(data, model, 'Y')
Source code in fast_causal_shap/core.py
def __init__(self, data: pd.DataFrame, model: Any, target_variable: str) -> None:
    """
    Set up a FastCausalSHAP explainer for a dataset, a model and a target column.

    The arguments are validated first; afterwards the causal graph, the gamma
    weights and every cache start out empty.

    Parameters
    ----------
    data : pd.DataFrame
        Non-empty dataset holding the feature columns and the target column.
    model : Any
        Object exposing a ``predict`` method and a ``feature_names_in_``
        attribute (for example a fitted sklearn classifier or regressor).
    target_variable : str
        Name of the target column; it must be one of ``data.columns``.

    Raises
    ------
    TypeError
        If ``data`` is not a pandas DataFrame.
    ValueError
        If ``data`` is empty or ``target_variable`` is not one of its columns.
    AttributeError
        If ``model`` lacks ``predict`` or ``feature_names_in_``.

    Examples
    --------
    >>> from sklearn.ensemble import RandomForestRegressor
    >>> import pandas as pd
    >>>
    >>> data = pd.DataFrame({'X1': [1, 2, 3], 'X2': [4, 5, 6], 'Y': [7, 8, 9]})
    >>> model = RandomForestRegressor()
    >>> model.fit(data[['X1', 'X2']], data['Y'])
    >>>
    >>> shap = FastCausalSHAP(data, model, 'Y')
    """
    # --- validate the dataset -------------------------------------------
    is_dataframe = isinstance(data, pd.DataFrame)
    if not is_dataframe:
        raise TypeError("data must be a pandas DataFrame")

    if data.empty:
        raise ValueError("data must not be empty")

    column_index = data.columns
    if target_variable not in column_index:
        missing_target_message = (
            f"target_variable '{target_variable}' not found in data columns. "
            f"Available columns: {list(column_index)}"
        )
        raise ValueError(missing_target_message)

    # --- validate the model ---------------------------------------------
    if not hasattr(model, "predict"):
        raise AttributeError("model must have a predict method")

    if not hasattr(model, "feature_names_in_"):
        unfitted_model_message = (
            "model must have 'feature_names_in_' attribute. "
            "Ensure the model has been fitted before passing it."
        )
        raise AttributeError(unfitted_model_message)

    # --- caller-supplied inputs -----------------------------------------
    self.data: pd.DataFrame = data
    self.model: Any = model
    self.target_variable: str = target_variable

    # --- causal state, empty until causal strengths are loaded -----------
    self.gamma: Optional[Dict[str, float]] = None
    self.ida_graph: Optional[nx.DiGraph] = None
    self.feature_depths: Dict[str, int] = {}
    self.causal_paths: Dict[str, List[List[str]]] = {}

    # --- caches ----------------------------------------------------------
    self.regression_models: Dict[Tuple[str, Tuple[str, ...]], Tuple[Any, float]] = (
        {}
    )
    self.path_cache: Dict[Any, float] = {}
|
compute_modified_shap_proba
compute_modified_shap_proba(x: Series, is_classifier: bool = False) -> Dict[str, float]
TreeSHAP-inspired computation using causal paths and dynamic programming.
Source code in fast_causal_shap/core.py
def compute_modified_shap_proba(
    self, x: pd.Series, is_classifier: bool = False
) -> Dict[str, float]:
    """TreeSHAP-inspired computation using causal paths and dynamic programming.

    For every feature that has recorded causal paths, the contribution of each
    path is accumulated as ``shapley_weight(m, d) * gamma[feature] * delta_v``
    over the attainable subset sizes ``m`` of the other nodes on the path.
    If the raw attributions do not sum to zero they are rescaled so that
    they sum to ``f(x) - E[f(X)]``.

    Parameters
    ----------
    x : pd.Series
        The instance to explain; its index must contain every name in
        ``model.feature_names_in_``.
    is_classifier : bool, default False
        When True, ``predict_proba(...)[:, 1]`` is used instead of ``predict``.

    Returns
    -------
    Dict[str, float]
        One attribution per non-target column of ``self.data``.

    Raises
    ------
    ValueError
        If ``load_causal_strengths`` has not been called (``gamma`` is None)
        or ``x`` lacks a feature required by the model.
    TypeError
        If ``x`` is not a pandas Series.
    """
    if self.gamma is None:
        raise ValueError(
            "Must call load_causal_strengths before computing SHAP values"
        )
    if not isinstance(x, pd.Series):
        raise TypeError(f"x must be a pandas Series, got {type(x).__name__}")

    # validate x contains required features
    model_features = list(self.model.feature_names_in_)
    missing_features = set(model_features) - set(x.index)
    if missing_features:
        raise ValueError(
            f"x is missing required features: {missing_features}. "
            f"Required features: {model_features}"
        )

    features = [col for col in self.data.columns if col != self.target_variable]
    phi_causal = {feature: 0.0 for feature in features}

    # Feed the model exactly its own features, in its own order, for both the
    # background data and the explained row.  Dropping only the target column
    # would keep the data's column order (and any extra columns), which can
    # disagree with the order the model was fitted on.
    background = self.data[model_features]
    x_row = x[model_features].to_frame().T
    if is_classifier:
        E_fX = self.model.predict_proba(background)[:, 1].mean()
        f_x = self.model.predict_proba(x_row)[0][1]
    else:
        E_fX = self.model.predict(background).mean()
        f_x = self.model.predict(x_row)[0]

    sorted_features = sorted(features, key=lambda f: self.feature_depths.get(f, 0))
    max_path_length = max(self.feature_depths.values(), default=0)

    # Pre-compute m! * (d - m - 1)! / d! for every subset size m < d.
    shapley_weights = {}
    for m in range(max_path_length + 1):
        for d in range(m + 1, max_path_length + 1):
            shapley_weights[(m, d)] = (
                factorial(m) * factorial(d - m - 1)
            ) / factorial(d)

    # Track contributions using dynamic programming (EXTEND-like logic in TreeSHAP)
    # m_values will accumulate contributions from subsets (use combinatorial logic)
    # Essentially, values in m_values[k] represent how many ways there are
    # to select k nodes from the path seen so far.
    for feature in sorted_features:
        if feature not in self.causal_paths:
            continue
        gamma_feature = self.gamma.get(feature, 0)
        for path in self.causal_paths[feature]:
            path_features = [n for n in path if n != self.target_variable]
            d = len(path_features)
            m_values = defaultdict(float)
            m_values[0] = 1.0
            for node in path_features:
                if node == feature:
                    continue
                new_m_values: defaultdict[int, float] = defaultdict(float)
                for m, val in m_values.items():
                    new_m_values[m + 1] += val
                    new_m_values[m] += val
                m_values = new_m_values
            # Only the keys (attainable subset sizes) are consumed here; pairs
            # (m, d) outside the pre-computed table get a weight of 0.
            for m in m_values:
                weight = shapley_weights.get((m, d), 0) * gamma_feature
                # _compute_path_delta_v is defined outside this block.
                delta_v = self._compute_path_delta_v(
                    feature, path, m, x, is_classifier
                )
                phi_causal[feature] += weight * delta_v

    # Rescale so the attributions add up to the prediction gap f(x) - E[f(X)].
    sum_phi = sum(phi_causal.values())
    if sum_phi != 0:
        scaling_factor = (f_x - E_fX) / sum_phi
        phi_causal = {k: v * scaling_factor for k, v in phi_causal.items()}
    return phi_causal
|
compute_v_do
compute_v_do(S: List[str], x_S: Dict[str, float], is_classifier: bool = False) -> float
Compute interventional expectations with caching.
Source code in fast_causal_shap/core.py
def compute_v_do(
    self, S: List[str], x_S: Dict[str, float], is_classifier: bool = False
) -> float:
    """Compute interventional expectations with caching.

    The features in ``S`` are fixed to the values in ``x_S``; every other
    non-target variable is drawn in topological order of the intervened graph,
    from its marginal when it has no (non-target) parents and conditionally on
    its parents otherwise.  The model is evaluated on that single sampled row
    and the result is memoised in ``self.path_cache``.

    Parameters
    ----------
    S : List[str]
        Names of the intervened features; each must be a key of ``x_S``.
    x_S : Dict[str, float]
        Values assigned to the intervened features.
    is_classifier : bool, default False
        When True, ``predict_proba(...)[:, 1]`` is used instead of ``predict``.

    Returns
    -------
    float
        Mean model output over the sampled row(s).
    """
    # The prediction mode is part of the key: a probability cached for a
    # classifier call must not be served to a regressor call (or vice versa).
    cache_key = (
        frozenset(S),
        tuple(sorted(x_S.items())),
        bool(is_classifier),
    )
    if cache_key in self.path_cache:
        return self.path_cache[cache_key]

    variables_order = self.get_topological_order(S)
    intervened = set(S)
    sample = {feature: x_S[feature] for feature in S}
    for feature in variables_order:
        if feature in intervened or feature == self.target_variable:
            continue
        parents = self.get_parents(feature)
        # Intervened parents already sit in `sample` with their x_S value;
        # the other parents were drawn earlier thanks to the topological order.
        parent_values = {
            p: sample[p] for p in parents if p != self.target_variable
        }
        if not parent_values:
            sample[feature] = self.sample_marginal(feature)
        else:
            sample[feature] = self.sample_conditional(feature, parent_values)

    intervened_data = pd.DataFrame([sample])
    # Present the columns in the order the model expects.
    intervened_data = intervened_data[self.model.feature_names_in_]
    if is_classifier:
        probas = self.model.predict_proba(intervened_data)[:, 1]
    else:
        probas = self.model.predict(intervened_data)
    result = float(np.mean(probas))
    self.path_cache[cache_key] = result
    return result
|
get_parents
get_parents(feature: str) -> List[str]
Returns the parent features for a given feature in the causal graph.
Source code in fast_causal_shap/core.py
def get_parents(self, feature: str) -> List[str]:
    """Return the direct predecessors of ``feature`` in the causal graph.

    An empty list is returned while no causal graph has been set.
    """
    graph = self.ida_graph
    return [] if graph is None else list(graph.predecessors(feature))
|
get_topological_order
get_topological_order(S: List[str]) -> List[str]
Returns the topological order of variables after intervening on subset S.
Source code in fast_causal_shap/core.py
def get_topological_order(self, S: List[str]) -> List[str]:
    """Returns the topological order of variables after intervening on subset S.

    Intervening on a feature removes all of its incoming edges.  Data columns
    that are absent from the graph are added as isolated nodes so that every
    variable appears in the result.

    Parameters
    ----------
    S : List[str]
        Features being intervened on.

    Returns
    -------
    List[str]
        Nodes in topological order, or an empty list when no causal graph
        has been set.

    Raises
    ------
    ValueError
        If the intervened graph still contains a cycle.
    """
    if self.ida_graph is None:
        return []
    G_intervened = self.ida_graph.copy()
    for feature in S:
        G_intervened.remove_edges_from(list(G_intervened.in_edges(feature)))
    # Add missing columns in column order rather than via a set, so the node
    # insertion order does not depend on string hashing between runs.
    missing_nodes = [
        column for column in self.data.columns if column not in G_intervened
    ]
    G_intervened.add_nodes_from(missing_nodes)
    try:
        return list(nx.topological_sort(G_intervened))
    except nx.NetworkXUnfeasible as err:
        # Keep the NetworkX error attached as the explicit cause.
        raise ValueError("The causal graph contains cycles.") from err
|
is_on_causal_path
is_on_causal_path(feature: str, target_feature: str) -> bool
Check whether feature appears on any causal path recorded for target_feature.
Source code in fast_causal_shap/core.py
def is_on_causal_path(self, feature: str, target_feature: str) -> bool:
    """Check whether ``feature`` appears on any causal path recorded for ``target_feature``.

    ``self.causal_paths[target_feature]`` holds a list of paths, each path
    being a list of node names, so membership has to be tested inside every
    path rather than against the outer list.

    Parameters
    ----------
    feature : str
        Node name to look for.
    target_feature : str
        Key into ``self.causal_paths`` whose paths are searched.

    Returns
    -------
    bool
        True if some recorded path contains ``feature``; False otherwise,
        including when nothing is recorded for ``target_feature``.
    """
    recorded = self.causal_paths.get(target_feature)
    if not recorded:
        return False
    # A bare string entry is compared by equality (not substring) so that a
    # flat list of node names keeps working as it did before.
    return any(
        feature == entry if isinstance(entry, str) else feature in entry
        for entry in recorded
    )
|
load_causal_strengths
load_causal_strengths(json_file_path: str) -> Dict[str, float]
Load causal strengths from JSON file and compute gamma values.
Source code in fast_causal_shap/core.py
def load_causal_strengths(self, json_file_path: str) -> Dict[str, float]:
    """Load causal strengths from JSON file and compute gamma values.

    The file must contain a non-empty JSON list of objects with a ``"Pair"``
    entry of the form ``"source->target"`` and a ``"Mean_Causal_Effect"``
    entry; objects whose effect is ``null`` are skipped.  The edges form a
    directed graph over the data columns, cycles are removed, and feature
    depths and causal paths are recomputed.

    For each non-target column the total effect on the target is the sum,
    over all simple paths to the target, of the product of the edge weights
    along the path.  ``gamma[k]`` is the absolute total effect of ``k``
    divided by the sum of absolute total effects (0.0 for every column when
    that sum is zero).

    Parameters
    ----------
    json_file_path : str
        Path of the JSON file to read.

    Returns
    -------
    Dict[str, float]
        The gamma value of every data column (also stored in ``self.gamma``).

    Raises
    ------
    TypeError
        If ``json_file_path`` is not a string.
    ValueError
        If the path is not an existing file, the file is not valid JSON, the
        JSON is not a non-empty list, or an entry is malformed.
    """
    import os

    if not isinstance(json_file_path, str):
        raise TypeError("json_file_path must be a string")
    if not os.path.isfile(json_file_path):
        raise ValueError("json_file_path must be a valid file path")
    try:
        # JSON text is UTF-8; do not depend on the platform's default encoding.
        with open(json_file_path, "r", encoding="utf-8") as f:
            causal_effects_list = json.load(f)
    except json.JSONDecodeError as e:
        raise ValueError(f"Invalid JSON file: {json_file_path}. Error: {e}") from e
    if not isinstance(causal_effects_list, list):
        raise ValueError(
            f"JSON file must has a list, got {type(causal_effects_list).__name__}"
        )
    if not causal_effects_list:
        raise ValueError("JSON file contains an empty list")

    G = nx.DiGraph()
    G.add_nodes_from(list(self.data.columns))
    for item in causal_effects_list:
        try:
            pair = item["Pair"]
            mean_causal_effect = item["Mean_Causal_Effect"]
        except (KeyError, TypeError) as err:
            raise ValueError(
                "Each JSON entry must be an object with 'Pair' and "
                f"'Mean_Causal_Effect' keys, got {item!r}"
            ) from err
        if mean_causal_effect is None:
            continue
        try:
            source, target = pair.split("->")
        except (AttributeError, ValueError) as err:
            raise ValueError(
                f"Invalid 'Pair' value {pair!r}; expected 'source->target'"
            ) from err
        G.add_edge(source.strip(), target.strip(), weight=mean_causal_effect)

    self.ida_graph = G.copy()
    removed_edges = self.remove_cycles()
    if removed_edges:
        logger.info(
            f"Removed {len(removed_edges)} edges to make the graph acyclic:"
        )
        for source, target, weight in removed_edges:
            logger.info(f"  {source} -> {target} (weight: {weight})")

    # Cached interventional values were computed against the previous graph.
    self.path_cache.clear()
    self._compute_feature_depths()
    self._compute_causal_paths()

    # Gamma is derived from the graph the rest of the class works with
    # (after cycle removal), not from the raw edge list, so edges that were
    # dropped to break cycles contribute nowhere.
    effect_graph = self.ida_graph
    features = self.data.columns.tolist()
    beta_dict = {}
    for feature in features:
        if feature == self.target_variable:
            continue
        try:
            paths = list(
                nx.all_simple_paths(
                    effect_graph, source=feature, target=self.target_variable
                )
            )
        except nx.NetworkXNoPath:
            continue
        total_effect = 0
        for path in paths:
            effect = 1
            for start, end in zip(path, path[1:]):
                effect *= effect_graph[start][end]["weight"]
            total_effect += effect
        if total_effect != 0:
            beta_dict[feature] = total_effect

    total_causal_effect = sum(abs(beta) for beta in beta_dict.values())
    if total_causal_effect == 0:
        self.gamma = {k: 0.0 for k in features}
    else:
        self.gamma = {
            k: abs(beta_dict.get(k, 0.0)) / total_causal_effect for k in features
        }
    return self.gamma
|
remove_cycles
remove_cycles() -> List[Tuple[str, str, float]]
Detects cycles in the graph and removes edges causing cycles.
Returns a list of removed edges.
Source code in fast_causal_shap/core.py
def remove_cycles(self) -> List[Tuple[str, str, float]]:
    """
    Detects cycles in the graph and removes edges causing cycles.

    Works on a copy of ``self.ida_graph``: while a cycle exists, the edge of
    that cycle with the smallest absolute weight is removed (the first such
    edge on ties).  The pruned copy then replaces ``self.ida_graph``.

    Returns
    -------
    List[Tuple[str, str, float]]
        The removed edges as ``(source, target, weight)`` in removal order;
        empty when no graph is set or nothing had to be removed.
    """
    if self.ida_graph is None:
        return []
    G = self.ida_graph.copy()
    removed_edges: List[Tuple[str, str, float]] = []

    while True:
        # simple_cycles is a generator: pull only the first cycle instead of
        # materialising every simple cycle of the graph on each pass.  An
        # acyclic graph simply yields nothing.
        cycle = next(nx.simple_cycles(G), None)
        if cycle is None:
            break

        # Find the edge with the smallest absolute weight in the cycle;
        # the last node wraps around to the first.
        min_weight = float("inf")
        edge_to_remove = None
        for source, target in zip(cycle, cycle[1:] + cycle[:1]):
            if not G.has_edge(source, target):
                continue
            weight = abs(G[source][target]["weight"])
            if weight < min_weight:
                min_weight = weight
                edge_to_remove = (source, target)

        if edge_to_remove is None:
            # No edge could be selected (e.g. no weight compared below
            # infinity); stop rather than loop forever.
            break

        source, target = edge_to_remove
        removed_edges.append((source, target, G[source][target]["weight"]))
        G.remove_edge(source, target)

    # Update the graph
    self.ida_graph = G
    return removed_edges
|
sample_conditional
sample_conditional(feature: str, parent_values: Dict[str, float]) -> float
Sample a value for a feature conditioned on its parent features.
Source code in fast_causal_shap/core.py
def sample_conditional(
    self, feature: str, parent_values: Dict[str, float]
) -> float:
    """Sample a value for a feature conditioned on its parent features.

    A linear regression of ``feature`` on its non-target parents is fitted on
    ``self.data`` the first time it is needed and cached together with the
    standard deviation of its residuals.  The sample is drawn from a normal
    distribution centred on the regression prediction for ``parent_values``
    with that standard deviation.  Without any non-target parent the value
    comes from ``sample_marginal`` instead.

    Parameters
    ----------
    feature : str
        Feature to sample.
    parent_values : Dict[str, float]
        Value of every non-target parent of ``feature``.

    Returns
    -------
    float
        The sampled value.
    """
    # Use one canonical (sorted) parent order for the cache key, the training
    # columns and the prediction input.  The key was already order-independent;
    # if the order reported by get_parents ever differed for the same parent
    # set, a cached model would otherwise be fed permuted inputs.
    effective_parents = sorted(
        p for p in self.get_parents(feature) if p != self.target_variable
    )
    if not effective_parents:
        return self.sample_marginal(feature)

    model_key = (feature, tuple(effective_parents))
    if model_key not in self.regression_models:
        X = self.data[effective_parents].values
        y = self.data[feature].values
        reg = LinearRegression()
        reg.fit(X, y)
        residuals = y - reg.predict(X)
        self.regression_models[model_key] = (reg, residuals.std())

    reg, std = self.regression_models[model_key]
    parent_values_array = np.array(
        [parent_values[parent] for parent in effective_parents]
    ).reshape(1, -1)
    mean = reg.predict(parent_values_array)[0]
    return np.random.normal(mean, std)
|
sample_marginal
sample_marginal(feature: str) -> float
Sample a value from the marginal distribution of the specified feature.
Source code in fast_causal_shap/core.py
def sample_marginal(self, feature: str) -> float:
    """Draw one randomly chosen observed value of ``feature`` from the data."""
    observed_column = self.data[feature]
    single_draw = observed_column.sample(n=1)
    return single_draw.iloc[0]
|