1
0
Fork 0
mirror of https://gitlab.rlp.net/mobitar/ReCo.jl.git synced 2024-09-19 19:01:17 +00:00

Added all methods for RL

This commit is contained in:
MoBit 2021-12-12 18:27:56 +01:00
parent 029f7f29f2
commit dc6a4ee7b7
6 changed files with 212 additions and 48 deletions

View file

@ -13,6 +13,10 @@ mutable struct Particle
end end
end end
function gen_tmp_particle()
return Particle(0, SVector(0.0, 0.0), 0.0)
end
function restrict_coordinate(value::Float64, half_box_len::Float64) function restrict_coordinate(value::Float64, half_box_len::Float64)
if value < -half_box_len if value < -half_box_len
value += 2 * half_box_len value += 2 * half_box_len

View file

@ -2,7 +2,6 @@ module ReCo
export init_sim, run_sim export init_sim, run_sim
include("utils.jl")
include("PreVector.jl") include("PreVector.jl")
include("Particle.jl") include("Particle.jl")
include("reinforcement_learning.jl") include("reinforcement_learning.jl")

View file

@ -1,7 +1,9 @@
module ReCoRL module ReCoRL
using ReinforcementLearning using ReinforcementLearning
using Flux: InvDecay
using Intervals using Intervals
using StaticArrays: SVector
import Base: run import Base: run
@ -87,7 +89,11 @@ mutable struct EnvParams
) )
end end
state_space = Vector{Tuple{DistanceState,DirectionState}}(undef, n_states) state_space = Vector{
Tuple{Union{DistanceState,Nothing},Union{DirectionState,Nothing}}
}(
undef, n_states
)
ind = 1 ind = 1
for distance_state in distance_state_space for distance_state in distance_state_space
@ -114,21 +120,48 @@ mutable struct Env <: AbstractEnv
end end
end end
struct Params function gen_policy(n_states, n_actions)
return QBasedPolicy(;
learner=MonteCarloLearner(;
approximator=TabularQApproximator(;
n_state=n_states, n_action=n_actions, opt=InvDecay(1.0)
),
),
explorer=EpsilonGreedyExplorer(0.1),
)
end
struct Params{H<:AbstractHook}
envs::Vector{Env} envs::Vector{Env}
# agents agents::Vector{Agent}
hooks::Vector{H}
actions::Vector{Tuple{Float64,Float64}} actions::Vector{Tuple{Float64,Float64}}
env_params::EnvParams env_params::EnvParams
n_steps_before_actions_update::Int64 n_steps_before_actions_update::Int64
min_distance²::Vector{Float64}
r⃗₁₂_to_min_distance_particle::Vector{SVector{2,Float64}}
function Params( function Params{H}(
n_particles::Int64, env_params::EnvParams, n_steps_before_actions_update::Int64 n_particles::Int64, env_params::EnvParams, n_steps_before_actions_update::Int64
) ) where {H<:AbstractHook}
policies = [
gen_policy(
length(rl_params.env_params.state_space),
length(rl_params.env_params.action_space),
) for i in 1:n_particles
]
agents = [
Agent(; policy=policy, trajectory=VectorSARTTrajectory()) for policy in policies
]
return new( return new(
Vector{Env}(undef, n_particles), [Env(env_params, gen_tmp_particle()) for i in 1:n_particles],
agents,
[H() for i in 1:n_particles],
Vector{Tuple{Float64,Float64}}(undef, n_particles), Vector{Tuple{Float64,Float64}}(undef, n_particles),
env_params, env_params,
n_steps_before_actions_update, n_steps_before_actions_update,
zeros(n_particles),
fill(SVector(0.0, 0.0), n_particles),
) )
end end
end end
@ -141,17 +174,103 @@ RLBase.action_space(env::Env) = env.params.action_space
RLBase.reward(env::Env) = env.params.reward RLBase.reward(env::Env) = env.params.reward
function pre_integration_hook!() end function pre_integration_hook!(rl_params::Params, n_particles::Int64)
for i in 1:n_particles
env = rl_params.envs[i]
agent = rl_params.agents[i]
action = agent(env)
rl_params.actions[i] = action
function integration_hook() end agent(PRE_ACT_STAGE, env, action)
rl_params.hooks[i](PRE_ACT_STAGE, agent, env, action)
end
function post_integration_hook() end return nothing
end
function state_hook(
id1::Int64, id2::Int64, r⃗₁₂::SVector{2,Float64}, distance²::Float64, rl_params::Params
)
if rl_params.min_distance²[id1] > distance²
rl_params.min_distance²[id1] = distance²
rl_params.r⃗₁₂_to_min_distance_particle[id1] = r⃗₁₂
end
if rl_params.min_distance²[id2] > distance²
rl_params.min_distance²[id2] = distance²
rl_params.r⃗₁₂_to_min_distance_particle[id2] = -r⃗₁₂
end
return nothing
end
function integration_hook(particle::Particle, rl_params::Params, δt::Float64)
action = rl_params.actions[particle.id]
particle.tmp_c += action[1] * δt
particle.φ += action[2] * δt
return nothing
end
function post_integration_hook(
rl_params::Params, n_particles::Int64, particles::Vector{Particle}
)
env_direction_state = rl_params.env_params.direction_state_space[1]
for i in 1:n_particles
env = rl_params.envs[i]
agent = rl_params.agents[i]
min_distance = sqrt(rl_params.min_distance²[i])
env_distance_state::Union{DistanceState,Nothing} = nothing
for distance_state in rl_params.env_params.distance_state_space
if min_distance in distance_state.interval
env_distance_state = distance_state
break
end
end
if isnothing(env_distance_state)
env.state = (nothing, nothing)
else
r⃗₁₂ = rl_params.r⃗₁₂_to_min_distance_particle[i]
si, co = sincos(particles[i].φ)
#=
Angle between two vectors
e = (co, si)
angle = acos(dot(r⃗₁₂, e) / (norm(r⃗₁₂) * norm(e)))
norm(r⃗₁₂) == min_distance
norm(e) == 1
=#
direction = acos((r⃗₁₂[1] * co + r⃗₁₂[2] * si) / min_distance)
for direction_state in rl_params.env_params.direction_state_space
if direction in direction_state
env_direction_state = direction_state
end
end
env.state = (env_distance_state, env_direction_state)
end
agent(POST_ACT_STAGE, env)
rl_params.hooks[i](POST_ACT_STAGE, agent, env)
end
return nothing
end
function run( function run(
n_episodes::Int64=100, n_episodes::Int64=100,
episode_duration::Float64=5.0, episode_duration::Float64=5.0,
update_actions_at::Float64=0.1, update_actions_at::Float64=0.1,
n_particles::Int64=100, n_particles::Int64=10,
) )
@assert n_episodes > 0 @assert n_episodes > 0
@assert episode_duration > 0 @assert episode_duration > 0
@ -161,32 +280,61 @@ function run(
Random.seed!(42) Random.seed!(42)
# envs
# agents
# pre_experiment
sim_consts = gen_sim_consts(n_particles, v₀) sim_consts = gen_sim_consts(n_particles, v₀)
env_params = EnvParams(sim_consts.particle_radius, sim_consts.skin_r) env_params = EnvParams(sim_consts.particle_radius, sim_consts.skin_r)
rl_params = Params(n_particles, env_params, n_steps_before_actions_update) rl_params = Params{TotalRewardPerEpisode}(
n_particles, env_params, n_steps_before_actions_update
)
for episode in 1:n_episodes for i in 1:n_particles
# reset env = rl_params.envs[i]
# pre_episode agent = rl_params.agents[i]
dir = init_sim_with_sim_consts(; sim_consts, parent_dir="RL") rl_params.hooks[i](PRE_EXPERIMENT_STAGE, agent, env)
agent(PRE_EXPERIMENT_STAGE, env)
run_sim(
dir;
duration=episode_duration,
seed=rand(1:typemax(Int64)),
rl_params=rl_params,
skin_r=skin_r,
)
end end
return nothing for episode in 1:n_episodes
dir, particles = init_sim_with_sim_consts(; sim_consts, parent_dir="RL")
for i in 1:n_particles
env = rl_params.envs[i]
agent = rl_params.agents[i]
rl_params.hooks[i](PRE_EPISODE_STAGE, agent, env)
agent(PRE_EPISODE_STAGE, env)
end
for i in 1:n_particles
rl_params.envs[i].particle = particles[i]
rl_params.envs[i].state = (nothing, nothing)
end
rl_params.env_params.reward = 0.0
run_sim(
dir; duration=episode_duration, seed=rand(1:typemax(Int64)), rl_params=rl_params
)
for i in 1:n_particles
env = rl_params.envs[i]
agent = rl_params.agents[i]
rl_params.hooks[i](POST_EPISODE_STAGE, agent, env)
agent(POST_EPISODE_STAGE, env)
end
end
for i in 1:n_particles
env = rl_params.envs[i]
agent = rl_params.agents[i]
rl_params.hooks[i](POST_EXPERIMENT_STAGE, agent, env)
end
return rl_params.hooks
end end
end # module end # module

View file

@ -78,6 +78,7 @@ function run_sim(
args = ( args = (
v₀=sim_consts.v₀, v₀=sim_consts.v₀,
δt=sim_consts.δt,
skin_r=sim_consts.skin_r, skin_r=sim_consts.skin_r,
skin_r²=sim_consts.skin_r^2, skin_r²=sim_consts.skin_r^2,
n_snapshots=n_snapshots, n_snapshots=n_snapshots,

View file

@ -114,31 +114,31 @@ function init_sim_with_sim_consts(
bundle = Bundle(n_particles, 1) bundle = Bundle(n_particles, 1)
save_snapshot!(bundle, 1, 0.0, particles) save_snapshot!(bundle, 1, 0.0, particles)
particles = nothing dir = exports_dir
if length(parent_dir) > 0 if length(parent_dir) > 0
exports_dir *= "/$parent_dir" dir *= "/$parent_dir"
end end
start_datetime = now() start_datetime = now()
exports_dir *= "/$(start_datetime)_N=$(sim_consts.n_particles)_v=$(sim_consts.v₀)_#$(rand(1000:9999))" dir *= "/$(start_datetime)_N=$(sim_consts.n_particles)_v=$(sim_consts.v₀)_#$(rand(1000:9999))"
if length(comment) > 0 if length(comment) > 0
exports_dir *= "_$comment" dir *= "_$comment"
end end
mkpath(exports_dir) mkpath(dir)
task = @async write_struct_to_json(sim_consts, "$exports_dir/sim_consts") task = @async write_struct_to_json(sim_consts, "$dir/sim_consts")
save_bundle(exports_dir, bundle, 1, 0.0) save_bundle(dir, bundle, 1, 0.0)
runs_dir = "$exports_dir/runs" runs_dir = "$dir/runs"
mkpath(runs_dir) mkpath(runs_dir)
wait(task) wait(task)
return exports_dir return (dir, particles)
end end
function init_sim(; function init_sim(;
@ -155,5 +155,5 @@ function init_sim(;
n_particles, v₀, δt, packing_ratio, skin_to_interaction_r_ratio n_particles, v₀, δt, packing_ratio, skin_to_interaction_r_ratio
) )
return init_sim_with_sim_consts(sim_consts, exports_dir, parent_dir, comment) return init_sim_with_sim_consts(sim_consts, exports_dir, parent_dir, comment)[1]
end end

View file

@ -41,7 +41,7 @@ function update_verlet_lists!(args, cl)
end end
function euler!( function euler!(
args, integration_hook::F, actions::Vector{Tuple{Float64,Float64}} args, state_hook::F, integration_hook::F, rl_params::Union{ReCoRL.Params,Nothing}
) where {F<:Function} ) where {F<:Function}
for id1 in 1:(args.n_particles - 1) for id1 in 1:(args.n_particles - 1)
p1 = args.particles[id1] p1 = args.particles[id1]
@ -55,6 +55,8 @@ function euler!(
p1_c, p2.c, args.interaction_r², args.half_box_len p1_c, p2.c, args.interaction_r², args.half_box_len
) )
state_hook(id1, id2, r⃗₁₂, distance², rl_params)
if overlapping if overlapping
c = args.c₁ / (distance²^4) * (args.c₂ / (distance²^3) - 1.0) c = args.c₁ / (distance²^4) * (args.c₂ / (distance²^3) - 1.0)
@ -76,7 +78,7 @@ function euler!(
restrict_coordinates!(p, args.half_box_len) restrict_coordinates!(p, args.half_box_len)
integration_hook(p, actions) integration_hook(p, rl_params, args.δt)
p.c = p.tmp_c p.c = p.tmp_c
end end
@ -86,6 +88,13 @@ end
wait(::Nothing) = nothing wait(::Nothing) = nothing
gen_run_hooks(::Nothing, args...) = false
function gen_run_hooks(rl_params::ReCoRL.Params, integration_step::Int64)
return (itegration_step == 1) ||
(integration_step % rl_params.n_steps_before_actions_update == 0)
end
function simulate( function simulate(
args, args,
δt::Float64, δt::Float64,
@ -108,7 +117,8 @@ function simulate(
cl = CellList(args.particles_c, args.box; parallel=false) cl = CellList(args.particles_c, args.box; parallel=false)
cl = update_verlet_lists!(args, cl) cl = update_verlet_lists!(args, cl)
update_actions = true run_hooks = false
state_hook = empty_hook
start_time = now() start_time = now()
println("Started simulation at $start_time.") println("Started simulation at $start_time.")
@ -134,16 +144,18 @@ function simulate(
cl = update_verlet_lists!(args, cl) cl = update_verlet_lists!(args, cl)
end end
update_actions = integration_step % rl_params.n_steps_before_actions_update == 0 run_hooks = gen_run_hooks(rl_params, integration_step)
if update_actions if run_hooks
pre_integration_hook!(rl_params) pre_integration_hook!(rl_params, args.n_particles)
state_hook = ReCoRL.state_hook
end end
euler!(args, integration_hook, rl.params.actions) euler!(args, state_hook, integration_hook, rl.params.actions)
if update_actions if run_hooks
post_integration_hook(rl_params) post_integration_hook(rl_params, args.n_particles, args.particles)
state_hook = empty_hook
end end
end end