Skip to content

Commit fa17e0f

Browse files
author
Moritz Schauer
authored
Merge pull request #166 from mschauer/do
Add do operator, make new version, improve API of pcalg.
2 parents 42359ae + a9d25ff commit fa17e0f

12 files changed

Lines changed: 107 additions & 37 deletions

File tree

Project.toml

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -26,54 +26,56 @@ Tables = "bd369af6-aec1-5ad0-b16a-f7cc5008161c"
2626
TabularDisplay = "3eeacb1d-13c2-54cc-9b18-30c86af3cadb"
2727
ThreadsX = "ac1d9e8a-700a-412c-b207-f0111f4b6c0d"
2828

29+
[weakdeps]
30+
GraphMakie = "1ecd5474-83a3-4783-bb4f-06765db800d2"
31+
GraphRecipes = "bd48cda9-67a9-57be-86fa-5b3c104eda73"
32+
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
33+
TikzGraphs = "b4f28e30-c73f-5eaf-a395-8a9db949a742"
34+
35+
[extensions]
36+
GraphMakieExt = "GraphMakie"
37+
GraphRecipesExt = ["GraphRecipes", "Plots"]
38+
TikzGraphsExt = "TikzGraphs"
39+
2940
[compat]
3041
Combinatorics = "1.0"
3142
DelimitedFiles = "1.6, 1.7, 1.8, 1.9"
3243
Distances = "0.8, 0.9, 0.10"
3344
Distributions = "0.22, 0.23, 0.24, 0.25"
3445
GraphMakie = "0.5"
3546
GraphRecipes = "0.5"
36-
Graphs = "1.5, 1.6, 1.7, 1.8, 1.9"
37-
LRUCache = "1.4, 1.5"
47+
Graphs = "1.5, 1.6, 1.7, 1.8, 1.9, 1.10, 1.11, 1.12"
48+
LRUCache = "1.4, 1.5, 1.6"
3849
LinearAlgebra = "1"
3950
LinkedLists = "0.1"
4051
LogarithmicNumbers = "1.4"
4152
Memoization = "0.2"
42-
MetaGraphs = "0.7"
53+
MetaGraphs = "0.7, 0.8"
4354
NearestNeighbors = "0.4"
44-
OffsetArrays = "1.12"
55+
OffsetArrays = "1.12, 1.13, 1.14"
56+
OrderedCollections = "1.7.0"
4557
Plots = "1"
4658
PrecompileTools = "1.1, 1.2"
47-
ProgressMeter = "1.9"
59+
ProgressMeter = "1.9, 1.10"
4860
Random = "1"
4961
Requires = "1.3"
50-
SpecialFunctions = "1.8, 2.0, 2.1, 2.2, 2.3"
62+
SpecialFunctions = "1.8, 2.0, 2.1, 2.2, 2.3, 2.4"
5163
Statistics = "1"
52-
Tables = "1.6, 1.7, 1.8, 1.9, 1.10, 1.11"
64+
Tables = "1.6, 1.7, 1.8, 1.9, 1.10, 1.11, 1.12"
5365
TabularDisplay = "1.2"
5466
ThreadsX = "0.1"
5567
TikzGraphs = "1.3, 1.4"
5668
julia = "1.6, 1.7, 1.8, 1.9, 1.10"
5769

58-
[extensions]
59-
GraphMakieExt = "GraphMakie"
60-
GraphRecipesExt = ["GraphRecipes", "Plots"]
61-
TikzGraphsExt = "TikzGraphs"
62-
6370
[extras]
6471
DelimitedFiles = "8bb1440f-4735-579b-a4ab-409b98df4dab"
6572
GraphMakie = "1ecd5474-83a3-4783-bb4f-06765db800d2"
6673
GraphRecipes = "bd48cda9-67a9-57be-86fa-5b3c104eda73"
74+
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
6775
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
6876
StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
6977
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
7078
TikzGraphs = "b4f28e30-c73f-5eaf-a395-8a9db949a742"
7179

7280
[targets]
73-
test = ["Test", "StatsBase", "DelimitedFiles"]
74-
75-
[weakdeps]
76-
GraphMakie = "1ecd5474-83a3-4783-bb4f-06765db800d2"
77-
GraphRecipes = "bd48cda9-67a9-57be-86fa-5b3c104eda73"
78-
Plots = "91a5bcdd-55d7-5caf-9e0b-520d859cae80"
79-
TikzGraphs = "b4f28e30-c73f-5eaf-a395-8a9db949a742"
81+
test = ["Test", "StatsBase", "DelimitedFiles", "OrderedCollections"]

docs/src/library.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ CausalInference.meek_rule4
3333
pdag2dag!
3434
pdag_to_dag_dortarsi!
3535
pdag_to_dag_meek!
36+
CausalInference.do!
3637
```
3738

3839
## PC algorithm

src/CausalInference.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ include("combinations_without.jl")
4141
include("klentropy.jl")
4242
include("skeleton.jl")
4343
include("pdag.jl")
44+
include("do.jl")
4445
include("dsep.jl")
4546
include("meek.jl")
4647
include("pc.jl")

src/do.jl

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
2+
""""
3+
do!(g, v)
4+
5+
Graphical do operator, removes all
6+
incoming edges to vertex `v` in a DiGraph `g`.
7+
Returns the modified graph.
8+
"""
9+
function do!(g::DiGraph{T}, v::T) where {T}
10+
for u in collect(inneighbors(g, v))
11+
rem_edge!(g, u, v)
12+
end
13+
return g
14+
end

src/pc.jl

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,9 +114,30 @@ function _vskel(n::V, I, par...) where {V}
114114
dg
115115
end
116116

117+
"""
118+
pcalg(g, I; stable=true)
119+
120+
Perform the PC algorithm for a set of 1:n variables using the tests
121+
122+
I(u, v, [s1, ..., sn])
123+
124+
Use `IClosure(I, args)` to wrap a function f with signature
125+
126+
f(u, v, [s1, ..., sn], par...)
127+
128+
Returns the CPDAG as DiGraph. By default uses a stable and threaded versions
129+
of the skeleton algorithm. (This is the most recent interface)
130+
"""
131+
function pcalg(g::Graph, I; stable=true)
132+
g, S = stable ? skeleton_stable(g, I) : skeleton(g, I)
133+
dg = DiGraph(g) # use g to keep track of unoriented edges
134+
_, dg = orient_unshielded(g, dg, S)
135+
meek_rules!(dg)
136+
dg
137+
end
138+
117139
"""
118140
pcalg(n::V, I, par...; stable=true)
119-
pcalg(g, I, par...; stable=true)
120141
121142
Perform the PC algorithm for a set of 1:n variables using the tests
122143

src/skeleton.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,7 @@ function skeleton(g::SimpleGraph, I)
8383
@label nextedge
8484
end
8585
d = d + 1
86-
if isdone
86+
if isdone || d > n-2
8787
return g, S
8888
end
8989
end
@@ -136,7 +136,7 @@ function skeleton_stable(g, I)
136136
end
137137
empty!(remove)
138138
d = d + 1
139-
if isdone
139+
if isdone || d > n-2
140140
return g, S
141141
end
142142
end

test/dag_sampler.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
using Random, CausalInference, Statistics, Test, Graphs, LinearAlgebra
2+
using OrderedCollections
3+
sort2(d; kwargs...) = sort!(OrderedDict(d); kwargs...)
24
@testset "Dag-Zig-Zag" begin
3-
45
Random.seed!(1)
56

67
N = 100 # number of data points
@@ -28,12 +29,12 @@ using Random, CausalInference, Statistics, Test, Graphs, LinearAlgebra
2829
graphs, graph_pairs, hs, τs, ws, ts, scores = CausalInference.unzipgs(gs)
2930
@test score_dag(graphs[end], score) scores[end] + score_dag(DiGraph(n), score)
3031
@test !any(Graphs.is_cyclic.(graphs))
31-
posterior = sort(keyedreduce(+, graph_pairs, ws); byvalue=true, rev=true)
32+
posterior = sort2(keyedreduce(+, graph_pairs, ws); byvalue=true, rev=true)
3233
logΠ = map(g->score_dag(digraph(g, n), score), collect(keys(posterior)))
3334
Π = normalize(exp.(logΠ .- maximum(logΠ) ), 1)
3435
@test norm(collect(values(posterior)) - Π, 1) < 30/sqrt(iterations)
3536

36-
posterior = sort(keyedreduce(+, vpairs.(cpdag.(graphs)), ws); byvalue=true, rev=true)
37+
posterior = sort2(keyedreduce(+, vpairs.(cpdag.(graphs)), ws); byvalue=true, rev=true)
3738
@test first(posterior).first == [1=>2, 1=>3, 2=>1, 2=>4, 3=>1, 3=>4, 4=>5]
3839
end
3940
end
@@ -50,7 +51,7 @@ end #testset
5051
n = 3
5152
gs = dagzigzag(n; iterations);
5253
graphs, graph_pairs, hs, τs, ws, ts, scores = CausalInference.unzipgs(gs);
53-
posterior = sort(keyedreduce(+, graph_pairs, ws); byvalue=true, rev=true)
54+
posterior = sort2(keyedreduce(+, graph_pairs, ws); byvalue=true, rev=true)
5455
@test length(posterior) == 25
5556
@test maximum(values(posterior)) < 1/25 + 0.01
5657
@test minimum(values(posterior)) > 1/25 - 0.01

test/equations.jl

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,24 +17,21 @@ df = (x=x, v=v, w=w, z=z, s=s)
1717

1818
est_g, score = ges(df; penalty=1.0, parallel=true)
1919

20-
2120
est_dag= pdag2dag!(est_g)
2221

2322
scm= estimate_equations(df,est_dag)
2423

25-
display(scm)
24+
#display(scm)
2625

2726
#println(CI.SCM)
2827

29-
df_generated= generate_data(scm, 2000)
28+
df_generated = generate_data(scm, 2000)
3029

3130
println("df: ")
3231

33-
display(df)
32+
#display(df)
3433

3534
println("df_generated: ")
3635

37-
38-
39-
display(df_generated)
36+
#display(df_generated)
4037

test/runtests.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
using CausalInference
22
using Graphs
33
using Test
4-
4+
include("zintervention.jl")
55
include("exact.jl")
66
include("operators.jl")
77
include("ges.jl")

test/sampler.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
11
using Random, CausalInference, Statistics, Test, Graphs, LinearAlgebra
2+
using OrderedCollections
3+
sort2(d; kwargs...) = sort!(OrderedDict(d); kwargs...)
24
@testset "Zig-Zag" begin
35
Random.seed!(1)
46

@@ -21,7 +23,7 @@ using Random, CausalInference, Statistics, Test, Graphs, LinearAlgebra
2123
score = GaussianScore(C, N, penalty)
2224
gs = causalzigzag(n; score, κ, iterations)
2325
graphs, graph_pairs, hs, τs, ws, ts, scores = CausalInference.unzipgs(gs)
24-
posterior = sort(keyedreduce(+, graph_pairs, ws); byvalue=true, rev=true)
26+
posterior = sort2(keyedreduce(+, graph_pairs, ws); byvalue=true, rev=true)
2527

2628
# maximum aposteriori estimate
2729
@test first(posterior).first == [1=>2, 1=>3, 2=>1, 2=>4, 3=>1, 3=>4, 4=>5]
@@ -46,14 +48,14 @@ end #testset
4648

4749
gs = causalzigzag(n; iterations);
4850
graphs, graph_pairs, hs, τs, ws, ts, scores = CausalInference.unzipgs(gs);
49-
posterior = sort(keyedreduce(+, graph_pairs, ws); byvalue=true, rev=true)
51+
posterior = sort2(keyedreduce(+, graph_pairs, ws); byvalue=true, rev=true)
5052
@test length(posterior) == m
5153
@test norm(values(posterior) .- 1/m, 1)/2 < 0.05
5254
T_skew = sum(τs)
5355

5456
gs = causalzigzag(n; iterations, σ=1.0, ρ=0.0);
5557
graphs, graph_pairs, hs, τs, ws, ts, scores = CausalInference.unzipgs(gs);
56-
posterior = sort(keyedreduce(+, graph_pairs, ws); byvalue=true, rev=true)
58+
posterior = sort2(keyedreduce(+, graph_pairs, ws); byvalue=true, rev=true)
5759
@test length(posterior) == m
5860
@test norm(values(posterior) .- 1/m, 1)/2 < 0.05
5961
T = sum(τs)

0 commit comments

Comments
 (0)