mirror of
https://gitlab.rlp.net/mobitar/ReCo.jl.git
synced 2025-01-01 15:39:20 +00:00
Added all methods for RL
This commit is contained in:
parent
029f7f29f2
commit
dc6a4ee7b7
6 changed files with 212 additions and 48 deletions
|
@ -13,6 +13,10 @@ mutable struct Particle
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
function gen_tmp_particle()
|
||||||
|
return Particle(0, SVector(0.0, 0.0), 0.0)
|
||||||
|
end
|
||||||
|
|
||||||
function restrict_coordinate(value::Float64, half_box_len::Float64)
|
function restrict_coordinate(value::Float64, half_box_len::Float64)
|
||||||
if value < -half_box_len
|
if value < -half_box_len
|
||||||
value += 2 * half_box_len
|
value += 2 * half_box_len
|
||||||
|
|
|
@ -2,7 +2,6 @@ module ReCo
|
||||||
|
|
||||||
export init_sim, run_sim
|
export init_sim, run_sim
|
||||||
|
|
||||||
include("utils.jl")
|
|
||||||
include("PreVector.jl")
|
include("PreVector.jl")
|
||||||
include("Particle.jl")
|
include("Particle.jl")
|
||||||
include("reinforcement_learning.jl")
|
include("reinforcement_learning.jl")
|
||||||
|
|
|
@ -1,7 +1,9 @@
|
||||||
module ReCoRL
|
module ReCoRL
|
||||||
|
|
||||||
using ReinforcementLearning
|
using ReinforcementLearning
|
||||||
|
using Flux: InvDecay
|
||||||
using Intervals
|
using Intervals
|
||||||
|
using StaticArrays: SVector
|
||||||
|
|
||||||
import Base: run
|
import Base: run
|
||||||
|
|
||||||
|
@ -87,7 +89,11 @@ mutable struct EnvParams
|
||||||
)
|
)
|
||||||
end
|
end
|
||||||
|
|
||||||
state_space = Vector{Tuple{DistanceState,DirectionState}}(undef, n_states)
|
state_space = Vector{
|
||||||
|
Tuple{Union{DistanceState,Nothing},Union{DirectionState,Nothing}}
|
||||||
|
}(
|
||||||
|
undef, n_states
|
||||||
|
)
|
||||||
|
|
||||||
ind = 1
|
ind = 1
|
||||||
for distance_state in distance_state_space
|
for distance_state in distance_state_space
|
||||||
|
@ -114,21 +120,48 @@ mutable struct Env <: AbstractEnv
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
struct Params
|
function gen_policy(n_states, n_actions)
|
||||||
|
return QBasedPolicy(;
|
||||||
|
learner=MonteCarloLearner(;
|
||||||
|
approximator=TabularQApproximator(;
|
||||||
|
n_state=n_states, n_action=n_actions, opt=InvDecay(1.0)
|
||||||
|
),
|
||||||
|
),
|
||||||
|
explorer=EpsilonGreedyExplorer(0.1),
|
||||||
|
)
|
||||||
|
end
|
||||||
|
|
||||||
|
struct Params{H<:AbstractHook}
|
||||||
envs::Vector{Env}
|
envs::Vector{Env}
|
||||||
# agents
|
agents::Vector{Agent}
|
||||||
|
hooks::Vector{H}
|
||||||
actions::Vector{Tuple{Float64,Float64}}
|
actions::Vector{Tuple{Float64,Float64}}
|
||||||
env_params::EnvParams
|
env_params::EnvParams
|
||||||
n_steps_before_actions_update::Int64
|
n_steps_before_actions_update::Int64
|
||||||
|
min_distance²::Vector{Float64}
|
||||||
|
r⃗₁₂_to_min_distance_particle::Vector{SVector{2,Float64}}
|
||||||
|
|
||||||
function Params(
|
function Params{H}(
|
||||||
n_particles::Int64, env_params::EnvParams, n_steps_before_actions_update::Int64
|
n_particles::Int64, env_params::EnvParams, n_steps_before_actions_update::Int64
|
||||||
)
|
) where {H<:AbstractHook}
|
||||||
|
policies = [
|
||||||
|
gen_policy(
|
||||||
|
length(rl_params.env_params.state_space),
|
||||||
|
length(rl_params.env_params.action_space),
|
||||||
|
) for i in 1:n_particles
|
||||||
|
]
|
||||||
|
agents = [
|
||||||
|
Agent(; policy=policy, trajectory=VectorSARTTrajectory()) for policy in policies
|
||||||
|
]
|
||||||
return new(
|
return new(
|
||||||
Vector{Env}(undef, n_particles),
|
[Env(env_params, gen_tmp_particle()) for i in 1:n_particles],
|
||||||
|
agents,
|
||||||
|
[H() for i in 1:n_particles],
|
||||||
Vector{Tuple{Float64,Float64}}(undef, n_particles),
|
Vector{Tuple{Float64,Float64}}(undef, n_particles),
|
||||||
env_params,
|
env_params,
|
||||||
n_steps_before_actions_update,
|
n_steps_before_actions_update,
|
||||||
|
zeros(n_particles),
|
||||||
|
fill(SVector(0.0, 0.0), n_particles),
|
||||||
)
|
)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
@ -141,17 +174,103 @@ RLBase.action_space(env::Env) = env.params.action_space
|
||||||
|
|
||||||
RLBase.reward(env::Env) = env.params.reward
|
RLBase.reward(env::Env) = env.params.reward
|
||||||
|
|
||||||
function pre_integration_hook!() end
|
function pre_integration_hook!(rl_params::Params, n_particles::Int64)
|
||||||
|
for i in 1:n_particles
|
||||||
|
env = rl_params.envs[i]
|
||||||
|
agent = rl_params.agents[i]
|
||||||
|
action = agent(env)
|
||||||
|
rl_params.actions[i] = action
|
||||||
|
|
||||||
function integration_hook() end
|
agent(PRE_ACT_STAGE, env, action)
|
||||||
|
rl_params.hooks[i](PRE_ACT_STAGE, agent, env, action)
|
||||||
|
end
|
||||||
|
|
||||||
function post_integration_hook() end
|
return nothing
|
||||||
|
end
|
||||||
|
|
||||||
|
function state_hook(
|
||||||
|
id1::Int64, id2::Int64, r⃗₁₂::SVector{2,Float64}, distance²::Float64, rl_params::Params
|
||||||
|
)
|
||||||
|
if rl_params.min_distance²[id1] > distance²
|
||||||
|
rl_params.min_distance²[id1] = distance²
|
||||||
|
|
||||||
|
rl_params.r⃗₁₂_to_min_distance_particle[id1] = r⃗₁₂
|
||||||
|
end
|
||||||
|
|
||||||
|
if rl_params.min_distance²[id2] > distance²
|
||||||
|
rl_params.min_distance²[id2] = distance²
|
||||||
|
|
||||||
|
rl_params.r⃗₁₂_to_min_distance_particle[id2] = -r⃗₁₂
|
||||||
|
end
|
||||||
|
|
||||||
|
return nothing
|
||||||
|
end
|
||||||
|
|
||||||
|
function integration_hook(particle::Particle, rl_params::Params, δt::Float64)
|
||||||
|
action = rl_params.actions[particle.id]
|
||||||
|
|
||||||
|
particle.tmp_c += action[1] * δt
|
||||||
|
particle.φ += action[2] * δt
|
||||||
|
|
||||||
|
return nothing
|
||||||
|
end
|
||||||
|
|
||||||
|
function post_integration_hook(
|
||||||
|
rl_params::Params, n_particles::Int64, particles::Vector{Particle}
|
||||||
|
)
|
||||||
|
env_direction_state = rl_params.env_params.direction_state_space[1]
|
||||||
|
|
||||||
|
for i in 1:n_particles
|
||||||
|
env = rl_params.envs[i]
|
||||||
|
agent = rl_params.agents[i]
|
||||||
|
|
||||||
|
min_distance = sqrt(rl_params.min_distance²[i])
|
||||||
|
|
||||||
|
env_distance_state::Union{DistanceState,Nothing} = nothing
|
||||||
|
|
||||||
|
for distance_state in rl_params.env_params.distance_state_space
|
||||||
|
if min_distance in distance_state.interval
|
||||||
|
env_distance_state = distance_state
|
||||||
|
break
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
if isnothing(env_distance_state)
|
||||||
|
env.state = (nothing, nothing)
|
||||||
|
else
|
||||||
|
r⃗₁₂ = rl_params.r⃗₁₂_to_min_distance_particle[i]
|
||||||
|
si, co = sincos(particles[i].φ)
|
||||||
|
|
||||||
|
#=
|
||||||
|
Angle between two vectors
|
||||||
|
e = (co, si)
|
||||||
|
angle = acos(dot(r⃗₁₂, e) / (norm(r⃗₁₂) * norm(e)))
|
||||||
|
norm(r⃗₁₂) == min_distance
|
||||||
|
norm(e) == 1
|
||||||
|
=#
|
||||||
|
direction = acos((r⃗₁₂[1] * co + r⃗₁₂[2] * si) / min_distance)
|
||||||
|
|
||||||
|
for direction_state in rl_params.env_params.direction_state_space
|
||||||
|
if direction in direction_state
|
||||||
|
env_direction_state = direction_state
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
env.state = (env_distance_state, env_direction_state)
|
||||||
|
end
|
||||||
|
|
||||||
|
agent(POST_ACT_STAGE, env)
|
||||||
|
rl_params.hooks[i](POST_ACT_STAGE, agent, env)
|
||||||
|
end
|
||||||
|
|
||||||
|
return nothing
|
||||||
|
end
|
||||||
|
|
||||||
function run(
|
function run(
|
||||||
n_episodes::Int64=100,
|
n_episodes::Int64=100,
|
||||||
episode_duration::Float64=5.0,
|
episode_duration::Float64=5.0,
|
||||||
update_actions_at::Float64=0.1,
|
update_actions_at::Float64=0.1,
|
||||||
n_particles::Int64=100,
|
n_particles::Int64=10,
|
||||||
)
|
)
|
||||||
@assert n_episodes > 0
|
@assert n_episodes > 0
|
||||||
@assert episode_duration > 0
|
@assert episode_duration > 0
|
||||||
|
@ -161,32 +280,61 @@ function run(
|
||||||
|
|
||||||
Random.seed!(42)
|
Random.seed!(42)
|
||||||
|
|
||||||
# envs
|
|
||||||
# agents
|
|
||||||
# pre_experiment
|
|
||||||
|
|
||||||
sim_consts = gen_sim_consts(n_particles, v₀)
|
sim_consts = gen_sim_consts(n_particles, v₀)
|
||||||
|
|
||||||
env_params = EnvParams(sim_consts.particle_radius, sim_consts.skin_r)
|
env_params = EnvParams(sim_consts.particle_radius, sim_consts.skin_r)
|
||||||
|
|
||||||
rl_params = Params(n_particles, env_params, n_steps_before_actions_update)
|
rl_params = Params{TotalRewardPerEpisode}(
|
||||||
|
n_particles, env_params, n_steps_before_actions_update
|
||||||
|
)
|
||||||
|
|
||||||
for episode in 1:n_episodes
|
for i in 1:n_particles
|
||||||
# reset
|
env = rl_params.envs[i]
|
||||||
# pre_episode
|
agent = rl_params.agents[i]
|
||||||
|
|
||||||
dir = init_sim_with_sim_consts(; sim_consts, parent_dir="RL")
|
rl_params.hooks[i](PRE_EXPERIMENT_STAGE, agent, env)
|
||||||
|
agent(PRE_EXPERIMENT_STAGE, env)
|
||||||
run_sim(
|
|
||||||
dir;
|
|
||||||
duration=episode_duration,
|
|
||||||
seed=rand(1:typemax(Int64)),
|
|
||||||
rl_params=rl_params,
|
|
||||||
skin_r=skin_r,
|
|
||||||
)
|
|
||||||
end
|
end
|
||||||
|
|
||||||
return nothing
|
for episode in 1:n_episodes
|
||||||
|
dir, particles = init_sim_with_sim_consts(; sim_consts, parent_dir="RL")
|
||||||
|
|
||||||
|
for i in 1:n_particles
|
||||||
|
env = rl_params.envs[i]
|
||||||
|
agent = rl_params.agents[i]
|
||||||
|
|
||||||
|
rl_params.hooks[i](PRE_EPISODE_STAGE, agent, env)
|
||||||
|
agent(PRE_EPISODE_STAGE, env)
|
||||||
|
end
|
||||||
|
|
||||||
|
for i in 1:n_particles
|
||||||
|
rl_params.envs[i].particle = particles[i]
|
||||||
|
rl_params.envs[i].state = (nothing, nothing)
|
||||||
|
end
|
||||||
|
|
||||||
|
rl_params.env_params.reward = 0.0
|
||||||
|
|
||||||
|
run_sim(
|
||||||
|
dir; duration=episode_duration, seed=rand(1:typemax(Int64)), rl_params=rl_params
|
||||||
|
)
|
||||||
|
|
||||||
|
for i in 1:n_particles
|
||||||
|
env = rl_params.envs[i]
|
||||||
|
agent = rl_params.agents[i]
|
||||||
|
|
||||||
|
rl_params.hooks[i](POST_EPISODE_STAGE, agent, env)
|
||||||
|
agent(POST_EPISODE_STAGE, env)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
for i in 1:n_particles
|
||||||
|
env = rl_params.envs[i]
|
||||||
|
agent = rl_params.agents[i]
|
||||||
|
|
||||||
|
rl_params.hooks[i](POST_EXPERIMENT_STAGE, agent, env)
|
||||||
|
end
|
||||||
|
|
||||||
|
return rl_params.hooks
|
||||||
end
|
end
|
||||||
|
|
||||||
end # module
|
end # module
|
|
@ -78,6 +78,7 @@ function run_sim(
|
||||||
|
|
||||||
args = (
|
args = (
|
||||||
v₀=sim_consts.v₀,
|
v₀=sim_consts.v₀,
|
||||||
|
δt=sim_consts.δt,
|
||||||
skin_r=sim_consts.skin_r,
|
skin_r=sim_consts.skin_r,
|
||||||
skin_r²=sim_consts.skin_r^2,
|
skin_r²=sim_consts.skin_r^2,
|
||||||
n_snapshots=n_snapshots,
|
n_snapshots=n_snapshots,
|
||||||
|
|
20
src/setup.jl
20
src/setup.jl
|
@ -114,31 +114,31 @@ function init_sim_with_sim_consts(
|
||||||
bundle = Bundle(n_particles, 1)
|
bundle = Bundle(n_particles, 1)
|
||||||
save_snapshot!(bundle, 1, 0.0, particles)
|
save_snapshot!(bundle, 1, 0.0, particles)
|
||||||
|
|
||||||
particles = nothing
|
dir = exports_dir
|
||||||
|
|
||||||
if length(parent_dir) > 0
|
if length(parent_dir) > 0
|
||||||
exports_dir *= "/$parent_dir"
|
dir *= "/$parent_dir"
|
||||||
end
|
end
|
||||||
|
|
||||||
start_datetime = now()
|
start_datetime = now()
|
||||||
exports_dir *= "/$(start_datetime)_N=$(sim_consts.n_particles)_v=$(sim_consts.v₀)_#$(rand(1000:9999))"
|
dir *= "/$(start_datetime)_N=$(sim_consts.n_particles)_v=$(sim_consts.v₀)_#$(rand(1000:9999))"
|
||||||
|
|
||||||
if length(comment) > 0
|
if length(comment) > 0
|
||||||
exports_dir *= "_$comment"
|
dir *= "_$comment"
|
||||||
end
|
end
|
||||||
|
|
||||||
mkpath(exports_dir)
|
mkpath(dir)
|
||||||
|
|
||||||
task = @async write_struct_to_json(sim_consts, "$exports_dir/sim_consts")
|
task = @async write_struct_to_json(sim_consts, "$dir/sim_consts")
|
||||||
|
|
||||||
save_bundle(exports_dir, bundle, 1, 0.0)
|
save_bundle(dir, bundle, 1, 0.0)
|
||||||
|
|
||||||
runs_dir = "$exports_dir/runs"
|
runs_dir = "$dir/runs"
|
||||||
mkpath(runs_dir)
|
mkpath(runs_dir)
|
||||||
|
|
||||||
wait(task)
|
wait(task)
|
||||||
|
|
||||||
return exports_dir
|
return (dir, particles)
|
||||||
end
|
end
|
||||||
|
|
||||||
function init_sim(;
|
function init_sim(;
|
||||||
|
@ -155,5 +155,5 @@ function init_sim(;
|
||||||
n_particles, v₀, δt, packing_ratio, skin_to_interaction_r_ratio
|
n_particles, v₀, δt, packing_ratio, skin_to_interaction_r_ratio
|
||||||
)
|
)
|
||||||
|
|
||||||
return init_sim_with_sim_consts(sim_consts, exports_dir, parent_dir, comment)
|
return init_sim_with_sim_consts(sim_consts, exports_dir, parent_dir, comment)[1]
|
||||||
end
|
end
|
|
@ -41,7 +41,7 @@ function update_verlet_lists!(args, cl)
|
||||||
end
|
end
|
||||||
|
|
||||||
function euler!(
|
function euler!(
|
||||||
args, integration_hook::F, actions::Vector{Tuple{Float64,Float64}}
|
args, state_hook::F, integration_hook::F, rl_params::Union{ReCoRL.Params,Nothing}
|
||||||
) where {F<:Function}
|
) where {F<:Function}
|
||||||
for id1 in 1:(args.n_particles - 1)
|
for id1 in 1:(args.n_particles - 1)
|
||||||
p1 = args.particles[id1]
|
p1 = args.particles[id1]
|
||||||
|
@ -55,6 +55,8 @@ function euler!(
|
||||||
p1_c, p2.c, args.interaction_r², args.half_box_len
|
p1_c, p2.c, args.interaction_r², args.half_box_len
|
||||||
)
|
)
|
||||||
|
|
||||||
|
state_hook(id1, id2, r⃗₁₂, distance², rl_params)
|
||||||
|
|
||||||
if overlapping
|
if overlapping
|
||||||
c = args.c₁ / (distance²^4) * (args.c₂ / (distance²^3) - 1.0)
|
c = args.c₁ / (distance²^4) * (args.c₂ / (distance²^3) - 1.0)
|
||||||
|
|
||||||
|
@ -76,7 +78,7 @@ function euler!(
|
||||||
|
|
||||||
restrict_coordinates!(p, args.half_box_len)
|
restrict_coordinates!(p, args.half_box_len)
|
||||||
|
|
||||||
integration_hook(p, actions)
|
integration_hook(p, rl_params, args.δt)
|
||||||
|
|
||||||
p.c = p.tmp_c
|
p.c = p.tmp_c
|
||||||
end
|
end
|
||||||
|
@ -86,6 +88,13 @@ end
|
||||||
|
|
||||||
wait(::Nothing) = nothing
|
wait(::Nothing) = nothing
|
||||||
|
|
||||||
|
gen_run_hooks(::Nothing, args...) = false
|
||||||
|
|
||||||
|
function gen_run_hooks(rl_params::ReCoRL.Params, integration_step::Int64)
|
||||||
|
return (itegration_step == 1) ||
|
||||||
|
(integration_step % rl_params.n_steps_before_actions_update == 0)
|
||||||
|
end
|
||||||
|
|
||||||
function simulate(
|
function simulate(
|
||||||
args,
|
args,
|
||||||
δt::Float64,
|
δt::Float64,
|
||||||
|
@ -108,7 +117,8 @@ function simulate(
|
||||||
cl = CellList(args.particles_c, args.box; parallel=false)
|
cl = CellList(args.particles_c, args.box; parallel=false)
|
||||||
cl = update_verlet_lists!(args, cl)
|
cl = update_verlet_lists!(args, cl)
|
||||||
|
|
||||||
update_actions = true
|
run_hooks = false
|
||||||
|
state_hook = empty_hook
|
||||||
|
|
||||||
start_time = now()
|
start_time = now()
|
||||||
println("Started simulation at $start_time.")
|
println("Started simulation at $start_time.")
|
||||||
|
@ -134,16 +144,18 @@ function simulate(
|
||||||
cl = update_verlet_lists!(args, cl)
|
cl = update_verlet_lists!(args, cl)
|
||||||
end
|
end
|
||||||
|
|
||||||
update_actions = integration_step % rl_params.n_steps_before_actions_update == 0
|
run_hooks = gen_run_hooks(rl_params, integration_step)
|
||||||
|
|
||||||
if update_actions
|
if run_hooks
|
||||||
pre_integration_hook!(rl_params)
|
pre_integration_hook!(rl_params, args.n_particles)
|
||||||
|
state_hook = ReCoRL.state_hook
|
||||||
end
|
end
|
||||||
|
|
||||||
euler!(args, integration_hook, rl.params.actions)
|
euler!(args, state_hook, integration_hook, rl.params.actions)
|
||||||
|
|
||||||
if update_actions
|
if run_hooks
|
||||||
post_integration_hook(rl_params)
|
post_integration_hook(rl_params, args.n_particles, args.particles)
|
||||||
|
state_hook = empty_hook
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue