mirror of
https://gitlab.rlp.net/mobitar/ReCo.jl.git
synced 2025-01-01 15:39:20 +00:00
Added all methods for RL
This commit is contained in:
parent
029f7f29f2
commit
dc6a4ee7b7
6 changed files with 212 additions and 48 deletions
|
@ -13,6 +13,10 @@ mutable struct Particle
|
|||
end
|
||||
end
|
||||
|
||||
function gen_tmp_particle()
|
||||
return Particle(0, SVector(0.0, 0.0), 0.0)
|
||||
end
|
||||
|
||||
function restrict_coordinate(value::Float64, half_box_len::Float64)
|
||||
if value < -half_box_len
|
||||
value += 2 * half_box_len
|
||||
|
|
|
@ -2,7 +2,6 @@ module ReCo
|
|||
|
||||
export init_sim, run_sim
|
||||
|
||||
include("utils.jl")
|
||||
include("PreVector.jl")
|
||||
include("Particle.jl")
|
||||
include("reinforcement_learning.jl")
|
||||
|
|
|
@ -1,7 +1,9 @@
|
|||
module ReCoRL
|
||||
|
||||
using ReinforcementLearning
|
||||
using Flux: InvDecay
|
||||
using Intervals
|
||||
using StaticArrays: SVector
|
||||
|
||||
import Base: run
|
||||
|
||||
|
@ -87,7 +89,11 @@ mutable struct EnvParams
|
|||
)
|
||||
end
|
||||
|
||||
state_space = Vector{Tuple{DistanceState,DirectionState}}(undef, n_states)
|
||||
state_space = Vector{
|
||||
Tuple{Union{DistanceState,Nothing},Union{DirectionState,Nothing}}
|
||||
}(
|
||||
undef, n_states
|
||||
)
|
||||
|
||||
ind = 1
|
||||
for distance_state in distance_state_space
|
||||
|
@ -114,21 +120,48 @@ mutable struct Env <: AbstractEnv
|
|||
end
|
||||
end
|
||||
|
||||
struct Params
|
||||
function gen_policy(n_states, n_actions)
|
||||
return QBasedPolicy(;
|
||||
learner=MonteCarloLearner(;
|
||||
approximator=TabularQApproximator(;
|
||||
n_state=n_states, n_action=n_actions, opt=InvDecay(1.0)
|
||||
),
|
||||
),
|
||||
explorer=EpsilonGreedyExplorer(0.1),
|
||||
)
|
||||
end
|
||||
|
||||
struct Params{H<:AbstractHook}
|
||||
envs::Vector{Env}
|
||||
# agents
|
||||
agents::Vector{Agent}
|
||||
hooks::Vector{H}
|
||||
actions::Vector{Tuple{Float64,Float64}}
|
||||
env_params::EnvParams
|
||||
n_steps_before_actions_update::Int64
|
||||
min_distance²::Vector{Float64}
|
||||
r⃗₁₂_to_min_distance_particle::Vector{SVector{2,Float64}}
|
||||
|
||||
function Params(
|
||||
function Params{H}(
|
||||
n_particles::Int64, env_params::EnvParams, n_steps_before_actions_update::Int64
|
||||
)
|
||||
) where {H<:AbstractHook}
|
||||
policies = [
|
||||
gen_policy(
|
||||
length(rl_params.env_params.state_space),
|
||||
length(rl_params.env_params.action_space),
|
||||
) for i in 1:n_particles
|
||||
]
|
||||
agents = [
|
||||
Agent(; policy=policy, trajectory=VectorSARTTrajectory()) for policy in policies
|
||||
]
|
||||
return new(
|
||||
Vector{Env}(undef, n_particles),
|
||||
[Env(env_params, gen_tmp_particle()) for i in 1:n_particles],
|
||||
agents,
|
||||
[H() for i in 1:n_particles],
|
||||
Vector{Tuple{Float64,Float64}}(undef, n_particles),
|
||||
env_params,
|
||||
n_steps_before_actions_update,
|
||||
zeros(n_particles),
|
||||
fill(SVector(0.0, 0.0), n_particles),
|
||||
)
|
||||
end
|
||||
end
|
||||
|
@ -141,17 +174,103 @@ RLBase.action_space(env::Env) = env.params.action_space
|
|||
|
||||
RLBase.reward(env::Env) = env.params.reward
|
||||
|
||||
function pre_integration_hook!() end
|
||||
function pre_integration_hook!(rl_params::Params, n_particles::Int64)
|
||||
for i in 1:n_particles
|
||||
env = rl_params.envs[i]
|
||||
agent = rl_params.agents[i]
|
||||
action = agent(env)
|
||||
rl_params.actions[i] = action
|
||||
|
||||
function integration_hook() end
|
||||
agent(PRE_ACT_STAGE, env, action)
|
||||
rl_params.hooks[i](PRE_ACT_STAGE, agent, env, action)
|
||||
end
|
||||
|
||||
function post_integration_hook() end
|
||||
return nothing
|
||||
end
|
||||
|
||||
function state_hook(
|
||||
id1::Int64, id2::Int64, r⃗₁₂::SVector{2,Float64}, distance²::Float64, rl_params::Params
|
||||
)
|
||||
if rl_params.min_distance²[id1] > distance²
|
||||
rl_params.min_distance²[id1] = distance²
|
||||
|
||||
rl_params.r⃗₁₂_to_min_distance_particle[id1] = r⃗₁₂
|
||||
end
|
||||
|
||||
if rl_params.min_distance²[id2] > distance²
|
||||
rl_params.min_distance²[id2] = distance²
|
||||
|
||||
rl_params.r⃗₁₂_to_min_distance_particle[id2] = -r⃗₁₂
|
||||
end
|
||||
|
||||
return nothing
|
||||
end
|
||||
|
||||
function integration_hook(particle::Particle, rl_params::Params, δt::Float64)
|
||||
action = rl_params.actions[particle.id]
|
||||
|
||||
particle.tmp_c += action[1] * δt
|
||||
particle.φ += action[2] * δt
|
||||
|
||||
return nothing
|
||||
end
|
||||
|
||||
function post_integration_hook(
|
||||
rl_params::Params, n_particles::Int64, particles::Vector{Particle}
|
||||
)
|
||||
env_direction_state = rl_params.env_params.direction_state_space[1]
|
||||
|
||||
for i in 1:n_particles
|
||||
env = rl_params.envs[i]
|
||||
agent = rl_params.agents[i]
|
||||
|
||||
min_distance = sqrt(rl_params.min_distance²[i])
|
||||
|
||||
env_distance_state::Union{DistanceState,Nothing} = nothing
|
||||
|
||||
for distance_state in rl_params.env_params.distance_state_space
|
||||
if min_distance in distance_state.interval
|
||||
env_distance_state = distance_state
|
||||
break
|
||||
end
|
||||
end
|
||||
|
||||
if isnothing(env_distance_state)
|
||||
env.state = (nothing, nothing)
|
||||
else
|
||||
r⃗₁₂ = rl_params.r⃗₁₂_to_min_distance_particle[i]
|
||||
si, co = sincos(particles[i].φ)
|
||||
|
||||
#=
|
||||
Angle between two vectors
|
||||
e = (co, si)
|
||||
angle = acos(dot(r⃗₁₂, e) / (norm(r⃗₁₂) * norm(e)))
|
||||
norm(r⃗₁₂) == min_distance
|
||||
norm(e) == 1
|
||||
=#
|
||||
direction = acos((r⃗₁₂[1] * co + r⃗₁₂[2] * si) / min_distance)
|
||||
|
||||
for direction_state in rl_params.env_params.direction_state_space
|
||||
if direction in direction_state
|
||||
env_direction_state = direction_state
|
||||
end
|
||||
end
|
||||
|
||||
env.state = (env_distance_state, env_direction_state)
|
||||
end
|
||||
|
||||
agent(POST_ACT_STAGE, env)
|
||||
rl_params.hooks[i](POST_ACT_STAGE, agent, env)
|
||||
end
|
||||
|
||||
return nothing
|
||||
end
|
||||
|
||||
function run(
|
||||
n_episodes::Int64=100,
|
||||
episode_duration::Float64=5.0,
|
||||
update_actions_at::Float64=0.1,
|
||||
n_particles::Int64=100,
|
||||
n_particles::Int64=10,
|
||||
)
|
||||
@assert n_episodes > 0
|
||||
@assert episode_duration > 0
|
||||
|
@ -161,32 +280,61 @@ function run(
|
|||
|
||||
Random.seed!(42)
|
||||
|
||||
# envs
|
||||
# agents
|
||||
# pre_experiment
|
||||
|
||||
sim_consts = gen_sim_consts(n_particles, v₀)
|
||||
|
||||
env_params = EnvParams(sim_consts.particle_radius, sim_consts.skin_r)
|
||||
|
||||
rl_params = Params(n_particles, env_params, n_steps_before_actions_update)
|
||||
|
||||
for episode in 1:n_episodes
|
||||
# reset
|
||||
# pre_episode
|
||||
|
||||
dir = init_sim_with_sim_consts(; sim_consts, parent_dir="RL")
|
||||
|
||||
run_sim(
|
||||
dir;
|
||||
duration=episode_duration,
|
||||
seed=rand(1:typemax(Int64)),
|
||||
rl_params=rl_params,
|
||||
skin_r=skin_r,
|
||||
rl_params = Params{TotalRewardPerEpisode}(
|
||||
n_particles, env_params, n_steps_before_actions_update
|
||||
)
|
||||
|
||||
for i in 1:n_particles
|
||||
env = rl_params.envs[i]
|
||||
agent = rl_params.agents[i]
|
||||
|
||||
rl_params.hooks[i](PRE_EXPERIMENT_STAGE, agent, env)
|
||||
agent(PRE_EXPERIMENT_STAGE, env)
|
||||
end
|
||||
|
||||
return nothing
|
||||
for episode in 1:n_episodes
|
||||
dir, particles = init_sim_with_sim_consts(; sim_consts, parent_dir="RL")
|
||||
|
||||
for i in 1:n_particles
|
||||
env = rl_params.envs[i]
|
||||
agent = rl_params.agents[i]
|
||||
|
||||
rl_params.hooks[i](PRE_EPISODE_STAGE, agent, env)
|
||||
agent(PRE_EPISODE_STAGE, env)
|
||||
end
|
||||
|
||||
for i in 1:n_particles
|
||||
rl_params.envs[i].particle = particles[i]
|
||||
rl_params.envs[i].state = (nothing, nothing)
|
||||
end
|
||||
|
||||
rl_params.env_params.reward = 0.0
|
||||
|
||||
run_sim(
|
||||
dir; duration=episode_duration, seed=rand(1:typemax(Int64)), rl_params=rl_params
|
||||
)
|
||||
|
||||
for i in 1:n_particles
|
||||
env = rl_params.envs[i]
|
||||
agent = rl_params.agents[i]
|
||||
|
||||
rl_params.hooks[i](POST_EPISODE_STAGE, agent, env)
|
||||
agent(POST_EPISODE_STAGE, env)
|
||||
end
|
||||
end
|
||||
|
||||
for i in 1:n_particles
|
||||
env = rl_params.envs[i]
|
||||
agent = rl_params.agents[i]
|
||||
|
||||
rl_params.hooks[i](POST_EXPERIMENT_STAGE, agent, env)
|
||||
end
|
||||
|
||||
return rl_params.hooks
|
||||
end
|
||||
|
||||
end # module
|
|
@ -78,6 +78,7 @@ function run_sim(
|
|||
|
||||
args = (
|
||||
v₀=sim_consts.v₀,
|
||||
δt=sim_consts.δt,
|
||||
skin_r=sim_consts.skin_r,
|
||||
skin_r²=sim_consts.skin_r^2,
|
||||
n_snapshots=n_snapshots,
|
||||
|
|
20
src/setup.jl
20
src/setup.jl
|
@ -114,31 +114,31 @@ function init_sim_with_sim_consts(
|
|||
bundle = Bundle(n_particles, 1)
|
||||
save_snapshot!(bundle, 1, 0.0, particles)
|
||||
|
||||
particles = nothing
|
||||
dir = exports_dir
|
||||
|
||||
if length(parent_dir) > 0
|
||||
exports_dir *= "/$parent_dir"
|
||||
dir *= "/$parent_dir"
|
||||
end
|
||||
|
||||
start_datetime = now()
|
||||
exports_dir *= "/$(start_datetime)_N=$(sim_consts.n_particles)_v=$(sim_consts.v₀)_#$(rand(1000:9999))"
|
||||
dir *= "/$(start_datetime)_N=$(sim_consts.n_particles)_v=$(sim_consts.v₀)_#$(rand(1000:9999))"
|
||||
|
||||
if length(comment) > 0
|
||||
exports_dir *= "_$comment"
|
||||
dir *= "_$comment"
|
||||
end
|
||||
|
||||
mkpath(exports_dir)
|
||||
mkpath(dir)
|
||||
|
||||
task = @async write_struct_to_json(sim_consts, "$exports_dir/sim_consts")
|
||||
task = @async write_struct_to_json(sim_consts, "$dir/sim_consts")
|
||||
|
||||
save_bundle(exports_dir, bundle, 1, 0.0)
|
||||
save_bundle(dir, bundle, 1, 0.0)
|
||||
|
||||
runs_dir = "$exports_dir/runs"
|
||||
runs_dir = "$dir/runs"
|
||||
mkpath(runs_dir)
|
||||
|
||||
wait(task)
|
||||
|
||||
return exports_dir
|
||||
return (dir, particles)
|
||||
end
|
||||
|
||||
function init_sim(;
|
||||
|
@ -155,5 +155,5 @@ function init_sim(;
|
|||
n_particles, v₀, δt, packing_ratio, skin_to_interaction_r_ratio
|
||||
)
|
||||
|
||||
return init_sim_with_sim_consts(sim_consts, exports_dir, parent_dir, comment)
|
||||
return init_sim_with_sim_consts(sim_consts, exports_dir, parent_dir, comment)[1]
|
||||
end
|
|
@ -41,7 +41,7 @@ function update_verlet_lists!(args, cl)
|
|||
end
|
||||
|
||||
function euler!(
|
||||
args, integration_hook::F, actions::Vector{Tuple{Float64,Float64}}
|
||||
args, state_hook::F, integration_hook::F, rl_params::Union{ReCoRL.Params,Nothing}
|
||||
) where {F<:Function}
|
||||
for id1 in 1:(args.n_particles - 1)
|
||||
p1 = args.particles[id1]
|
||||
|
@ -55,6 +55,8 @@ function euler!(
|
|||
p1_c, p2.c, args.interaction_r², args.half_box_len
|
||||
)
|
||||
|
||||
state_hook(id1, id2, r⃗₁₂, distance², rl_params)
|
||||
|
||||
if overlapping
|
||||
c = args.c₁ / (distance²^4) * (args.c₂ / (distance²^3) - 1.0)
|
||||
|
||||
|
@ -76,7 +78,7 @@ function euler!(
|
|||
|
||||
restrict_coordinates!(p, args.half_box_len)
|
||||
|
||||
integration_hook(p, actions)
|
||||
integration_hook(p, rl_params, args.δt)
|
||||
|
||||
p.c = p.tmp_c
|
||||
end
|
||||
|
@ -86,6 +88,13 @@ end
|
|||
|
||||
wait(::Nothing) = nothing
|
||||
|
||||
gen_run_hooks(::Nothing, args...) = false
|
||||
|
||||
function gen_run_hooks(rl_params::ReCoRL.Params, integration_step::Int64)
|
||||
return (itegration_step == 1) ||
|
||||
(integration_step % rl_params.n_steps_before_actions_update == 0)
|
||||
end
|
||||
|
||||
function simulate(
|
||||
args,
|
||||
δt::Float64,
|
||||
|
@ -108,7 +117,8 @@ function simulate(
|
|||
cl = CellList(args.particles_c, args.box; parallel=false)
|
||||
cl = update_verlet_lists!(args, cl)
|
||||
|
||||
update_actions = true
|
||||
run_hooks = false
|
||||
state_hook = empty_hook
|
||||
|
||||
start_time = now()
|
||||
println("Started simulation at $start_time.")
|
||||
|
@ -134,16 +144,18 @@ function simulate(
|
|||
cl = update_verlet_lists!(args, cl)
|
||||
end
|
||||
|
||||
update_actions = integration_step % rl_params.n_steps_before_actions_update == 0
|
||||
run_hooks = gen_run_hooks(rl_params, integration_step)
|
||||
|
||||
if update_actions
|
||||
pre_integration_hook!(rl_params)
|
||||
if run_hooks
|
||||
pre_integration_hook!(rl_params, args.n_particles)
|
||||
state_hook = ReCoRL.state_hook
|
||||
end
|
||||
|
||||
euler!(args, integration_hook, rl.params.actions)
|
||||
euler!(args, state_hook, integration_hook, rl.params.actions)
|
||||
|
||||
if update_actions
|
||||
post_integration_hook(rl_params)
|
||||
if run_hooks
|
||||
post_integration_hook(rl_params, args.n_particles, args.particles)
|
||||
state_hook = empty_hook
|
||||
end
|
||||
end
|
||||
|
||||
|
|
Loading…
Reference in a new issue