Skip to content
5 changes: 4 additions & 1 deletion src/CausalInference.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ export pdag2dag!, pdag_to_dag_meek!, pdag_to_dag_dortarsi!
export count_moves_uniform, randcpdag, UniformScore, causalzigzag, dagzigzag
export keyedreduce

export estimate_equations, generate_data

#include("pinv.jl")
include("graphs.jl")
include("combinations_without.jl")
Expand All @@ -55,6 +57,7 @@ include("sampler.jl")
include("dag_sampler.jl")
include("misc2.jl")
include("exact.jl")
include("equations.jl")
#include("mcs.jl")

# Compatibility with the new "Package Extensions" (https://github.com/JuliaLang/julia/pull/47695)
Expand All @@ -73,4 +76,4 @@ function __init__()

end
end
end # end of module
end # end of module
135 changes: 135 additions & 0 deletions src/equations.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
using LinearAlgebra, Graphs, Tables, Random, Statistics

# Container for a fitted linear structural causal model
"""
    SCM

A struct representing a linear Structural Causal Model (SCM).

The three vectors are index-aligned: entry `i` of `coefficients` and `residuals`
belongs to `variables[i]`. `estimate_equations` fills them in lockstep and
`generate_data` looks up coefficients and residuals by the position of a variable.

# Fields
- `variables::Vector{String}`: The variable names.
- `coefficients::Vector{Vector{Float64}}`: One coefficient vector per variable. As filled
  by `estimate_equations`, the first entry is the intercept and the remaining entries are
  the slopes of the parents of the variable; a variable without parents only has the
  intercept.
- `residuals::Vector{Vector{Float64}}`: One residual vector per variable.
- `dag::DiGraph`: The directed graph representing the structure of the SCM.
"""
struct SCM
    variables::Vector{String}
    coefficients::Vector{Vector{Float64}}
    residuals::Vector{Vector{Float64}}
    dag::DiGraph
end

"""
    ols_compute(X, y)

Regress `y` on `X` with an intercept.

A column of ones is prepended to `X` and the resulting system is solved with `\\`.
Returns the tuple `(coef, resids)`, where `coef[1]` is the intercept, the remaining
entries are the slopes in the column order of `X`, and `resids` is `y` minus the
fitted values.
"""
function ols_compute(X, y)
    nobs = size(X, 1)
    design = hcat(ones(nobs), X)
    beta = design \ y
    fitted = design * beta
    return beta, y - fitted
end

# Function to estimate equations and return an SCM struct
"""
    estimate_equations(t, est_g::DiGraph)::SCM

Estimate linear equations from the given table `t` based on the structure of the
directed graph `est_g`.

Column `i` of `t` (in `Tables.columnnames` order) is treated as vertex `i` of `est_g`.
Every variable with parents is regressed on them by least squares with an intercept
(see `ols_compute`); a variable without parents gets its mean as the only coefficient.

# Arguments
- `t`: A table containing the data for estimation (supports any Tables.jl-compatible format).
- `est_g::DiGraph`: A directed graph representing the structure of the SCM.

# Returns
- `SCM`: A struct containing the variable names (the column names of `t`), their
  corresponding coefficients, residuals, and the DAG.

# Throws
- `ArgumentError`: if `t` is not a Tables.jl table, if `est_g` is cyclic, or if the
  number of columns of `t` differs from the number of vertices of `est_g`.
"""
function estimate_equations(t, est_g::DiGraph)::SCM
    Tables.istable(t) || throw(ArgumentError("Argument supports just Tables.jl types"))

    columns = Tables.columns(t)
    # Take the names from the column accessor: `Tables.schema` may be `nothing`, and
    # `propertynames` of a tuple of names yields the indices 1, 2, ... rather than the names.
    colnames = collect(Tables.columnnames(columns))

    # Check if it is a DAG
    if is_cyclic(est_g)
        throw(ArgumentError("The provided graph is cyclic -> est_g::DiGraph should be a DAG."))
    end

    # Columns and vertices are matched by position, so their counts have to agree.
    if length(colnames) != nv(est_g)
        throw(ArgumentError(
            "The table has $(length(colnames)) columns but est_g has $(nv(est_g)) vertices."
        ))
    end

    var_names = String[]
    coefficients = Vector{Vector{Float64}}()
    residuals = Vector{Vector{Float64}}()

    for (node_index, name) in enumerate(colnames)
        y = Tables.getcolumn(columns, name)
        # Same parent order as `generate_data`, which also relies on `inneighbors`.
        parents = inneighbors(est_g, node_index)

        push!(var_names, string(name))
        if isempty(parents)
            # Root node: the equation is just the mean plus noise.
            intercept = mean(y)
            push!(coefficients, [intercept])
            push!(residuals, y .- intercept)
        else
            X = hcat((Tables.getcolumn(columns, colnames[p]) for p in parents)...)
            coef, resid = ols_compute(X, y)
            # Every node is stored, so index i of the SCM always refers to vertex i.
            push!(coefficients, coef)
            push!(residuals, resid)
        end
    end

    return SCM(var_names, coefficients, residuals, est_g)
end

# Function to generate data from the SCM
"""
    generate_data(scm::SCM, N::Integer; rng=Random.default_rng())::NamedTuple

Generate data from the given SCM.

The variables are simulated in a topological order of `scm.dag`. Each variable is its
intercept, plus the linear combination of its already simulated parents, plus Gaussian
noise whose standard deviation is the sample standard deviation of the stored residuals
of that variable.

# Arguments
- `scm::SCM`: The structural causal model.
- `N::Integer`: The number of data points to generate.

# Keywords
- `rng::AbstractRNG`: Random number generator used for the noise. Defaults to the global
  generator, which is also what `randn(N)` draws from.

# Returns
- `NamedTuple`: A NamedTuple containing the generated data, with one column per vertex of
  `scm.dag`, ordered by vertex index and named after `scm.variables`.
"""
function generate_data(
    scm::SCM, N::Integer; rng::AbstractRNG=Random.default_rng()
)::NamedTuple
    dag = scm.dag
    # Stored by vertex index so the output order does not depend on hashing.
    simulated = Vector{Vector{Float64}}(undef, nv(dag))

    # Parents come before children, so every parent column exists when it is needed.
    for idx in topological_sort_by_dfs(dag)
        coef = scm.coefficients[idx]
        noise = std(scm.residuals[idx]) * randn(rng, N)
        parents = inneighbors(dag, idx)

        if length(coef) == 1 || isempty(parents)
            # Intercept-only equation.
            simulated[idx] = coef[1] .+ noise
        else
            X = hcat(ones(N), (simulated[p] for p in parents)...)
            simulated[idx] = X * coef .+ noise
        end
    end

    colnames = Tuple(Symbol(scm.variables[i]) for i in vertices(dag))
    return NamedTuple{colnames}(Tuple(simulated))
end
39 changes: 39 additions & 0 deletions test/equations.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
using CausalInference
using Random
using Test

# Fit an SCM on simulated linear-Gaussian data and check that the estimated model and the
# data generated from it have the expected structure.
@testset "equations" begin
    Random.seed!(1)

    # Generate some sample data to use with the GES algorithm
    N = 2000 # number of data points

    # define simple linear model with added noise
    x = randn(N)
    v = x + randn(N)*0.25
    w = x + randn(N)*0.25
    z = v + w + randn(N)*0.25
    s = z + randn(N)*0.25

    df = (x=x, v=v, w=w, z=z, s=s)

    est_g, score = ges(df; penalty=1.0, parallel=true)
    est_dag = pdag2dag!(est_g)

    scm = estimate_equations(df, est_dag)

    # One equation per column of the input table
    @test length(scm.variables) == length(df)
    @test length(scm.coefficients) == length(scm.variables)
    @test length(scm.residuals) == length(scm.variables)
    @test all(r -> length(r) == N, scm.residuals)
    # Every equation has an intercept, so its residuals sum to (numerically) zero
    @test all(r -> abs(sum(r)) / N < 1e-8, scm.residuals)

    df_generated = generate_data(scm, N)

    @test df_generated isa NamedTuple
    @test length(df_generated) == length(df)
    @test all(col -> length(col) == N, values(df_generated))
    @test all(col -> all(isfinite, col), values(df_generated))
end
1 change: 1 addition & 0 deletions test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,3 +18,4 @@ include("witness.jl")
include("fci.jl")
include("klentropy.jl")
include("backdoor.jl")
include("equations.jl")