Summary: Scroll down for a reproducible example, which should run from scratch in Julia if you have the packages listed in the using lines. (Note: the ODE has a complex, reusable structure, which is specified in a Gist that the script downloads and includes.)
Background: I have to repeatedly solve a large system of ODEs for different initial-condition vectors. In the example below there are 127 states/ODEs, but there could easily be 1000-2000. I will have to run these solves 100s-1000s of times for inference, so speed is essential.
The Puzzle: The short version is that, for the serial functions, the @simd version is much faster than the "plain", non-@simd version. But for the parallel versions, the @simd version is much slower. On top of that, in the parallel @simd case the answer, sum_of_solutions, varies between runs and is wrong.
I start Julia with JULIA_NUM_THREADS=auto julia, which on my machine gives 8 threads. I then make sure I never have more than 8 jobs spawned at once.
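For comparison, here is a minimal sketch of the simplest spawn-and-fetch pattern, in which every job is spawned up front and Julia's scheduler decides which thread runs which task (at most Threads.nthreads() tasks execute at the same moment). The names toy_solve and solve_all are hypothetical stand-ins, not functions from my script, and toy_solve is only a placeholder for a real ODE solve.

```julia
using Base.Threads: @spawn, nthreads

# Hypothetical stand-in for one ODE solve: a pure function of its input.
toy_solve(u0::Vector{Float64}) = cumsum(u0) .* 2.0

# Spawn one task per job. The core operation is passed in as an argument,
# so the same driver can be reused with different core operations.
function solve_all(core_op, inputs)
    tasks = [@spawn core_op(copy(u0)) for u0 in inputs]
    return fetch.(tasks)  # results are returned in the order of `inputs`
end

# Five unit vectors as toy initial conditions
inputs = [Float64[i == j ? 1.0 : 0.0 for j in 1:5] for i in 1:5]
results = solve_all(toy_solve, inputs)
println("threads: ", nthreads(), ", jobs: ", length(results))
```

With this pattern there is no manual bookkeeping of running tasks; whether it behaves any differently with the real ODE functions is something I have not tested.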
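The different calculation times, given as (runtime in seconds, sum_of_solutions); one representative run of each, taken from the full output further down:
# Output is (runtime, sum_of_solutions)
# serial, plain:   (0.878, 8.731365050398926)
# serial, @simd:   (0.046, 8.731365050398928)
# parallel, plain: (0.351, 8.731365050398926)
# parallel, @simd: (136.966, 9.61313614002137)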
As you can see, serial @simd gets the calculation down to 0.046 seconds, and parallel plain is 2.5 times faster than serial plain. But when I combine parallelization with the @simd function, I get runtimes of about 140 seconds, with variable and wrong answers to boot! Literally the only difference between the two parallelizing functions is the use of core_op_plain versus core_op_simd for the core ODE-solving operation.
It seems like @simd and @spawn must be conflicting somehow? I have the parallelizing function set up to never employ more than the number of CPU threads available (8 on my machine).
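One possibility worth ruling out is a data race on shared scratch memory. This is only an assumption: the parameter object p_Ds_v7 carries a terms field, and if the faster ODE function writes intermediate values into such a preallocated array, then tasks sharing one p would overwrite each other's work. The sketch below uses hypothetical names (make_params, toy_core_op, solve_all_isolated) and a toy operation in place of the real solve; it shows only the isolation pattern, in which each task gets its own deepcopy of the parameters.

```julia
using Base.Threads: @spawn

# Hypothetical parameter object with a preallocated scratch array,
# loosely modelled on the `terms` field of p_Ds_v7 (its real use is an assumption).
make_params(n) = (n = n, terms = zeros(n))

# Toy "core operation" that writes into p.terms before reading it back.
function toy_core_op(u, p)
    for i in 1:p.n
        p.terms[i] = 2.0 * u[i]   # scratch write: unsafe if p is shared across tasks
    end
    return sum(p.terms)
end

# Each task receives its own deepcopy of p, so no scratch memory is shared.
function solve_all_isolated(inputs, p)
    tasks = [@spawn toy_core_op(u, deepcopy(p)) for u in inputs]
    return fetch.(tasks)
end

p = make_params(4)
inputs = [fill(Float64(i), 4) for i in 1:8]
results = solve_all_isolated(inputs, p)
println(results)
```

If the parallel @simd answers became stable with a per-task copy of p, that would point toward shared mutable state rather than toward @simd itself, but I have not confirmed this for the real model.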
I am still learning Julia, so there is a chance that some smallish change could isolate the @simd calculations and prevent conflicts across threads (if that is what is happening). Any help is very much appreciated!
PS: Reproducible Example. The code below should provide a reproducible example in any Julia session running multiple threads. I also include my versioninfo() etc.:
My setup is:
Julia Version 1.7.3
Commit 742b9abb4d (2022-05-06 12:58 UTC)
Platform Info:
OS: macOS (x86_64-apple-darwin21.4.0)
CPU: Intel(R) Xeon(R) CPU E5-2697 v2 @ 2.70GHz
LIBM: libopenlibm
LLVM: libLLVM-12.0.1 (ORCJIT, ivybridge)
# Startup notes
# "If $JULIA_NUM_THREADS is set to auto, then the number of threads will be set to the number of CPU threads."
JULIA_NUM_THREADS=auto julia --startup-file=no
Threads.nthreads(): 8 # Number of CPU threads
using LinearAlgebra # for "I" in: Matrix{Float64}(I, 2, 2)
using Sundials # for CVODE_BDF
using Statistics # for mean(), max()
using DataFrames # for e.g. DataFrame()
using Dates # for e.g. DateTime,
using DifferentialEquations # for ODEProblem
using BenchmarkTools # for @benchmark
using Distributed # for workers
# Check that you have multiple threads
numthreads = Base.Threads.nthreads()
# Download & include the pre-saved model structure/rates (all precalculated for speed; 1.8 MB)
url = ""
download(url, "model_p_object.jl")
include("model_p_object.jl")
# Load the ODE functions
url = ""
download(url, "simd_vs_spawn_setup_v2.jl")
include("simd_vs_spawn_setup_v2.jl")
# Load the pre-saved model structure/rates (all precalculated for speed; 1.8 MB)
p_Es_v5 = load_ps_127();
# Set up output object
numstates = 127
number_of_solves = 10
solve_results1 = Array{Float64, 2}(undef, number_of_solves, numstates)
solve_results1 .= 0.0
solve_results2 = Array{Float64, 2}(undef, number_of_solves, numstates)
solve_results2 .= 0.0
# Precalculate the Es for use in the Ds
Es_tspan = (0.0, 60.0)
prob_Es_v7 = DifferentialEquations.ODEProblem(Es_v7_simd_sums, p_Es_v5.uE, Es_tspan, p_Es_v5);
sol_Es_v7 = solve(prob_Es_v7, CVODE_BDF(linear_solver=:GMRES), save_everystep=true,
abstol=1e-12, reltol=1e-9);
p_Ds_v7 = (n=p_Es_v5.n, params=p_Es_v5.params, p_indices=p_Es_v5.p_indices, p_TFs=p_Es_v5.p_TFs, uE=p_Es_v5.uE, terms=p_Es_v5.terms, sol_Es_v5=sol_Es_v7);
# Set up ODE inputs
u = collect(repeat([0.0], numstates));
u[2] = 1.0
du = similar(u)
du .= 0.0
p = p_Ds_v7;
t = 1.0
# ODE functions to integrate (single-step; ODE solvers will run this many many times)
@time Ds_v5_tmp(du,u,p,t)
@time Ds_v5_tmp(du,u,p,t)
@time Ds_v7_simd_sums(du,u,p,t)
@time Ds_v7_simd_sums(du,u,p,t)
#@btime Ds_v5_tmp(du,u,p,t)
# 7.819 ms (15847 allocations: 1.09 MiB)
#@btime Ds_v7_simd_sums(du,u,p,t)
# 155.858 μs (3075 allocations: 68.66 KiB)
tspan = (0.0, 1.0)
prob_Ds_v7 = DifferentialEquations.ODEProblem(Ds_v7_simd_sums, p_Ds_v7.uE, tspan, p_Ds_v7);
sol_Ds_v7 = solve(prob_Ds_v7, CVODE_BDF(linear_solver=:GMRES), save_everystep=false, abstol=1e-12, reltol=1e-9);
# This is the core operation; plain version (no @simd)
function core_op_plain(u, tspan, p_Ds_v7)
    # u.+0.0 makes a fresh copy of u, so the caller's vector is not modified
    prob_Ds_v5 = DifferentialEquations.ODEProblem(Ds_v5_tmp, u.+0.0, tspan, p_Ds_v7);
    sol_Ds_v5 = solve(prob_Ds_v5, CVODE_BDF(linear_solver=:GMRES), save_everystep=false, abstol=1e-12, reltol=1e-9);
    return sol_Ds_v5
end

# This is the core operation; @simd version
function core_op_simd(u, tspan, p_Ds_v7)
    # u.+0.0 makes a fresh copy of u, so the caller's vector is not modified
    prob_Ds_v7 = DifferentialEquations.ODEProblem(Ds_v7_simd_sums, u.+0.0, tspan, p_Ds_v7);
    sol_Ds_v7 = solve(prob_Ds_v7, CVODE_BDF(linear_solver=:GMRES), save_everystep=false, abstol=1e-12, reltol=1e-9);
    return sol_Ds_v7
end
@time core_op_plain(u, tspan, p_Ds_v7);
@time core_op_plain(u, tspan, p_Ds_v7);
@time core_op_simd(u, tspan, p_Ds_v7);
@time core_op_simd(u, tspan, p_Ds_v7);
function serial_with_plain_v7(tspan, p_Ds_v7, solve_results1; number_of_solves=10)
    start_time = Dates.now()
    for i in 1:number_of_solves
        # Temporary u
        solve_results1[i,:] .= 0.0
        # Change the ith state from 0.0 to 1.0
        solve_results1[i,i] = 1.0
        sol_Ds_v7 = core_op_plain(solve_results1[i,:], tspan, p_Ds_v7)
        # Store the final state of the solution in row i
        solve_results1[i,:] .= sol_Ds_v7.u[length(sol_Ds_v7.u)]
        # print("\n")
        # print(round.(sol_Ds_v7[length(sol_Ds_v7)], digits=3))
    end
    end_time = Dates.now()
    # Difference is in milliseconds; convert to seconds
    duration = (end_time - start_time).value / 1000.0
    sum_of_solutions = sum(sum.(solve_results1))
    return (duration, sum_of_solutions)
end
function serial_with_simd_v7(tspan, p_Ds_v7, solve_results1; number_of_solves=10)
    start_time = Dates.now()
    for i in 1:number_of_solves
        # Temporary u
        solve_results1[i,:] .= 0.0
        # Change the ith state from 0.0 to 1.0
        solve_results1[i,i] = 1.0
        sol_Ds_v7 = core_op_simd(solve_results1[i,:], tspan, p_Ds_v7)
        # Store the final state of the solution in row i
        solve_results1[i,:] .= sol_Ds_v7.u[length(sol_Ds_v7.u)]
        # print("\n")
        # print(round.(sol_Ds_v7[length(sol_Ds_v7)], digits=3))
    end
    end_time = Dates.now()
    # Difference is in milliseconds; convert to seconds
    duration = (end_time - start_time).value / 1000.0
    sum_of_solutions = sum(sum.(solve_results1))
    return (duration, sum_of_solutions)
end
# Output is (runtime, sum_of_solutions)
serial_with_plain_v7(tspan, p_Ds_v7, solve_results1; number_of_solves=number_of_solves)
serial_with_plain_v7(tspan, p_Ds_v7, solve_results1; number_of_solves=number_of_solves)
serial_with_plain_v7(tspan, p_Ds_v7, solve_results1; number_of_solves=number_of_solves)
# (duration, sum_of_solutions)
# (1.1, 8.731365050398926)
# (0.878, 8.731365050398926)
# (0.898, 8.731365050398926)
serial_with_simd_v7(tspan, p_Ds_v7, solve_results1; number_of_solves=number_of_solves)
serial_with_simd_v7(tspan, p_Ds_v7, solve_results1; number_of_solves=number_of_solves)
serial_with_simd_v7(tspan, p_Ds_v7, solve_results1; number_of_solves=number_of_solves)
# (duration, sum_of_solutions)
# (0.046, 8.731365050398928)
# (0.042, 8.731365050398928)
# (0.046, 8.731365050398928)
using Distributed
function parallel_with_plain_v5(tspan, p_Ds_v7, solve_results2; number_of_solves=10)
    start_time = Dates.now()
    number_of_threads = Base.Threads.nthreads()
    curr_numthreads = Base.Threads.nthreads()
    # Individual ODE solutions will occur over different timeperiods,
    # initial values, and parameters. We'd just like to load up the
    # cores for the first jobs in the list, then add jobs as earlier
    # jobs finish.
    tasks = Any[]
    tasks_started_TF = Bool[]
    tasks_fetched_TF = Bool[]
    task_numbers = Any[]
    task_inc = 0
    are_we_done = false
    current_running_tasks = Any[]
    # List the tasks
    for i in 1:number_of_solves
        # Temporary u
        solve_results2[i,:] .= 0.0
        # Change the ith state from 0.0 to 1.0
        solve_results2[i,i] = 1.0
        task_inc = task_inc + 1
        push!(tasks_started_TF, false) # Add a "false" to tasks_started_TF
        push!(tasks_fetched_TF, false) # Add a "false" to tasks_fetched_TF
        push!(task_numbers, task_inc)
    end
    # Total number of tasks
    num_tasks = length(tasks_fetched_TF)
    iteration_number = 0
    while(are_we_done == false)
        iteration_number = iteration_number+1
        # Launch tasks when thread (core) is available
        for j in 1:num_tasks
            if (tasks_fetched_TF[j] == false)
                if (tasks_started_TF[j] == false) && (curr_numthreads > 0)
                    # Start a task
                    push!(tasks, Base.Threads.@spawn core_op_plain(solve_results2[j,:], tspan, p_Ds_v7));
                    curr_numthreads = curr_numthreads-1;
                    tasks_started_TF[j] = true;
                    push!(current_running_tasks, task_numbers[j])
                end
            end
        end
        # Check for finished tasks
        tasks_to_check_TF = ((tasks_started_TF.==true) .+ (tasks_fetched_TF.==false)).==2
        if sum(tasks_to_check_TF .== true) > 0
            for k in 1:sum(tasks_to_check_TF)
                if (tasks_fetched_TF[current_running_tasks[k]] == false)
                    if (istaskstarted(tasks[k]) == true) && (istaskdone(tasks[k]) == true)
                        sol_Ds_v7 = fetch(tasks[k]);
                        solve_results2[current_running_tasks[k],:] .= sol_Ds_v7.u[length(sol_Ds_v7.u)].+0.0
                        tasks_fetched_TF[current_running_tasks[k]] = true
                        current_tasknum = current_running_tasks[k]
                        # Remove the finished task and free up a thread slot
                        deleteat!(tasks, k)
                        deleteat!(current_running_tasks, k)
                        curr_numthreads = curr_numthreads+1;
                        print("\nFinished task #")
                        print(current_tasknum)
                        print(", current task k=")
                        print(k)
                        break # break out of this loop, since you have modified current_running_tasks
                    end
                end
            end
        end
        are_we_done = sum(tasks_fetched_TF) == length(tasks_fetched_TF)
        # Test for concluding the while loop
        are_we_done && break
    end # END while(are_we_done == false)
    end_time = Dates.now()
    # Difference is in milliseconds; convert to seconds
    duration = (end_time - start_time).value / 1000.0
    sum_of_solutions = sum(sum.(solve_results2))
    return (duration, sum_of_solutions)
end
function parallel_with_simd_v7(tspan, p_Ds_v7, solve_results2; number_of_solves=10)
    start_time = Dates.now()
    number_of_threads = Base.Threads.nthreads()
    curr_numthreads = Base.Threads.nthreads()
    # Individual ODE solutions will occur over different timeperiods,
    # initial values, and parameters. We'd just like to load up the
    # cores for the first jobs in the list, then add jobs as earlier
    # jobs finish.
    tasks = Any[]
    tasks_started_TF = Bool[]
    tasks_fetched_TF = Bool[]
    task_numbers = Any[]
    task_inc = 0
    are_we_done = false
    current_running_tasks = Any[]
    # List the tasks
    for i in 1:number_of_solves
        # Temporary u
        solve_results2[i,:] .= 0.0
        # Change the ith state from 0.0 to 1.0
        solve_results2[i,i] = 1.0
        task_inc = task_inc + 1
        push!(tasks_started_TF, false) # Add a "false" to tasks_started_TF
        push!(tasks_fetched_TF, false) # Add a "false" to tasks_fetched_TF
        push!(task_numbers, task_inc)
    end
    # Total number of tasks
    num_tasks = length(tasks_fetched_TF)
    iteration_number = 0
    while(are_we_done == false)
        iteration_number = iteration_number+1
        # Launch tasks when thread (core) is available
        for j in 1:num_tasks
            if (tasks_fetched_TF[j] == false)
                if (tasks_started_TF[j] == false) && (curr_numthreads > 0)
                    # Start a task
                    push!(tasks, Base.Threads.@spawn core_op_simd(solve_results2[j,:], tspan, p_Ds_v7))
                    curr_numthreads = curr_numthreads-1;
                    tasks_started_TF[j] = true;
                    push!(current_running_tasks, task_numbers[j])
                end
            end
        end
        # Check for finished tasks
        tasks_to_check_TF = ((tasks_started_TF.==true) .+ (tasks_fetched_TF.==false)).==2
        if sum(tasks_to_check_TF .== true) > 0
            for k in 1:sum(tasks_to_check_TF)
                if (tasks_fetched_TF[current_running_tasks[k]] == false)
                    if (istaskstarted(tasks[k]) == true) && (istaskdone(tasks[k]) == true)
                        sol_Ds_v7 = fetch(tasks[k]);
                        solve_results2[current_running_tasks[k],:] .= sol_Ds_v7.u[length(sol_Ds_v7.u)].+0.0
                        tasks_fetched_TF[current_running_tasks[k]] = true
                        current_tasknum = current_running_tasks[k]
                        # Remove the finished task and free up a thread slot
                        deleteat!(tasks, k)
                        deleteat!(current_running_tasks, k)
                        curr_numthreads = curr_numthreads+1;
                        print("\nFinished task #")
                        print(current_tasknum)
                        print(", current task k=")
                        print(k)
                        break # break out of this loop, since you have modified current_running_tasks
                    end
                end
            end
        end
        are_we_done = sum(tasks_fetched_TF) == length(tasks_fetched_TF)
        # Test for concluding the while loop
        are_we_done && break
    end # END while(are_we_done == false)
    end_time = Dates.now()
    # Difference is in milliseconds; convert to seconds
    duration = (end_time - start_time).value / 1000.0
    sum_of_solutions = sum(sum.(solve_results2))
    return (duration, sum_of_solutions)
end
tspan = (0.0, 1.0)
parallel_with_plain_v5(tspan, p_Ds_v7, solve_results2; number_of_solves=number_of_solves)
# Faster than serial plain version
# (duration, sum_of_solutions)
# (0.351, 8.731365050398926)
# (0.343, 8.731365050398926)
# (0.366, 8.731365050398926)
parallel_with_simd_v7(tspan, p_Ds_v7, solve_results2; number_of_solves=number_of_solves)
# Dramatically slower than serial simd version
# (duration, sum_of_solutions)
# (136.966, 9.61313614002137)
# (141.843, 9.616688089683372)
Thanks again, Nick
