1
0
Fork 0
mirror of https://gitlab.rlp.net/mobitar/ReCo.jl.git synced 2025-01-01 15:39:20 +00:00

Added all methods for RL

This commit is contained in:
MoBit 2021-12-12 18:27:56 +01:00
parent 029f7f29f2
commit dc6a4ee7b7
6 changed files with 212 additions and 48 deletions

View file

@ -13,6 +13,10 @@ mutable struct Particle
end
end
function gen_tmp_particle()
return Particle(0, SVector(0.0, 0.0), 0.0)
end
function restrict_coordinate(value::Float64, half_box_len::Float64)
if value < -half_box_len
value += 2 * half_box_len

View file

@ -2,7 +2,6 @@ module ReCo
export init_sim, run_sim
include("utils.jl")
include("PreVector.jl")
include("Particle.jl")
include("reinforcement_learning.jl")

View file

@ -1,7 +1,9 @@
module ReCoRL
using ReinforcementLearning
using Flux: InvDecay
using Intervals
using StaticArrays: SVector
import Base: run
@ -87,7 +89,11 @@ mutable struct EnvParams
)
end
state_space = Vector{Tuple{DistanceState,DirectionState}}(undef, n_states)
state_space = Vector{
Tuple{Union{DistanceState,Nothing},Union{DirectionState,Nothing}}
}(
undef, n_states
)
ind = 1
for distance_state in distance_state_space
@ -114,21 +120,48 @@ mutable struct Env <: AbstractEnv
end
end
struct Params
function gen_policy(n_states, n_actions)
return QBasedPolicy(;
learner=MonteCarloLearner(;
approximator=TabularQApproximator(;
n_state=n_states, n_action=n_actions, opt=InvDecay(1.0)
),
),
explorer=EpsilonGreedyExplorer(0.1),
)
end
struct Params{H<:AbstractHook}
envs::Vector{Env}
# agents
agents::Vector{Agent}
hooks::Vector{H}
actions::Vector{Tuple{Float64,Float64}}
env_params::EnvParams
n_steps_before_actions_update::Int64
min_distance²::Vector{Float64}
r⃗₁₂_to_min_distance_particle::Vector{SVector{2,Float64}}
function Params(
function Params{H}(
n_particles::Int64, env_params::EnvParams, n_steps_before_actions_update::Int64
)
) where {H<:AbstractHook}
policies = [
gen_policy(
length(rl_params.env_params.state_space),
length(rl_params.env_params.action_space),
) for i in 1:n_particles
]
agents = [
Agent(; policy=policy, trajectory=VectorSARTTrajectory()) for policy in policies
]
return new(
Vector{Env}(undef, n_particles),
[Env(env_params, gen_tmp_particle()) for i in 1:n_particles],
agents,
[H() for i in 1:n_particles],
Vector{Tuple{Float64,Float64}}(undef, n_particles),
env_params,
n_steps_before_actions_update,
zeros(n_particles),
fill(SVector(0.0, 0.0), n_particles),
)
end
end
@ -141,17 +174,103 @@ RLBase.action_space(env::Env) = env.params.action_space
RLBase.reward(env::Env) = env.params.reward
function pre_integration_hook!() end
function pre_integration_hook!(rl_params::Params, n_particles::Int64)
for i in 1:n_particles
env = rl_params.envs[i]
agent = rl_params.agents[i]
action = agent(env)
rl_params.actions[i] = action
function integration_hook() end
agent(PRE_ACT_STAGE, env, action)
rl_params.hooks[i](PRE_ACT_STAGE, agent, env, action)
end
function post_integration_hook() end
return nothing
end
function state_hook(
id1::Int64, id2::Int64, r⃗₁₂::SVector{2,Float64}, distance²::Float64, rl_params::Params
)
if rl_params.min_distance²[id1] > distance²
rl_params.min_distance²[id1] = distance²
rl_params.r⃗₁₂_to_min_distance_particle[id1] = r⃗₁₂
end
if rl_params.min_distance²[id2] > distance²
rl_params.min_distance²[id2] = distance²
rl_params.r⃗₁₂_to_min_distance_particle[id2] = -r⃗₁₂
end
return nothing
end
function integration_hook(particle::Particle, rl_params::Params, δt::Float64)
action = rl_params.actions[particle.id]
particle.tmp_c += action[1] * δt
particle.φ += action[2] * δt
return nothing
end
function post_integration_hook(
rl_params::Params, n_particles::Int64, particles::Vector{Particle}
)
env_direction_state = rl_params.env_params.direction_state_space[1]
for i in 1:n_particles
env = rl_params.envs[i]
agent = rl_params.agents[i]
min_distance = sqrt(rl_params.min_distance²[i])
env_distance_state::Union{DistanceState,Nothing} = nothing
for distance_state in rl_params.env_params.distance_state_space
if min_distance in distance_state.interval
env_distance_state = distance_state
break
end
end
if isnothing(env_distance_state)
env.state = (nothing, nothing)
else
r⃗₁₂ = rl_params.r⃗₁₂_to_min_distance_particle[i]
si, co = sincos(particles[i].φ)
#=
Angle between two vectors
e = (co, si)
angle = acos(dot(r⃗₁₂, e) / (norm(r⃗₁₂) * norm(e)))
norm(r⃗₁₂) == min_distance
norm(e) == 1
=#
direction = acos((r⃗₁₂[1] * co + r⃗₁₂[2] * si) / min_distance)
for direction_state in rl_params.env_params.direction_state_space
if direction in direction_state
env_direction_state = direction_state
end
end
env.state = (env_distance_state, env_direction_state)
end
agent(POST_ACT_STAGE, env)
rl_params.hooks[i](POST_ACT_STAGE, agent, env)
end
return nothing
end
function run(
n_episodes::Int64=100,
episode_duration::Float64=5.0,
update_actions_at::Float64=0.1,
n_particles::Int64=100,
n_particles::Int64=10,
)
@assert n_episodes > 0
@assert episode_duration > 0
@ -161,32 +280,61 @@ function run(
Random.seed!(42)
# envs
# agents
# pre_experiment
sim_consts = gen_sim_consts(n_particles, v₀)
env_params = EnvParams(sim_consts.particle_radius, sim_consts.skin_r)
rl_params = Params(n_particles, env_params, n_steps_before_actions_update)
for episode in 1:n_episodes
# reset
# pre_episode
dir = init_sim_with_sim_consts(; sim_consts, parent_dir="RL")
run_sim(
dir;
duration=episode_duration,
seed=rand(1:typemax(Int64)),
rl_params=rl_params,
skin_r=skin_r,
rl_params = Params{TotalRewardPerEpisode}(
n_particles, env_params, n_steps_before_actions_update
)
for i in 1:n_particles
env = rl_params.envs[i]
agent = rl_params.agents[i]
rl_params.hooks[i](PRE_EXPERIMENT_STAGE, agent, env)
agent(PRE_EXPERIMENT_STAGE, env)
end
return nothing
for episode in 1:n_episodes
dir, particles = init_sim_with_sim_consts(; sim_consts, parent_dir="RL")
for i in 1:n_particles
env = rl_params.envs[i]
agent = rl_params.agents[i]
rl_params.hooks[i](PRE_EPISODE_STAGE, agent, env)
agent(PRE_EPISODE_STAGE, env)
end
for i in 1:n_particles
rl_params.envs[i].particle = particles[i]
rl_params.envs[i].state = (nothing, nothing)
end
rl_params.env_params.reward = 0.0
run_sim(
dir; duration=episode_duration, seed=rand(1:typemax(Int64)), rl_params=rl_params
)
for i in 1:n_particles
env = rl_params.envs[i]
agent = rl_params.agents[i]
rl_params.hooks[i](POST_EPISODE_STAGE, agent, env)
agent(POST_EPISODE_STAGE, env)
end
end
for i in 1:n_particles
env = rl_params.envs[i]
agent = rl_params.agents[i]
rl_params.hooks[i](POST_EXPERIMENT_STAGE, agent, env)
end
return rl_params.hooks
end
end # module

View file

@ -78,6 +78,7 @@ function run_sim(
args = (
v₀=sim_consts.v₀,
δt=sim_consts.δt,
skin_r=sim_consts.skin_r,
skin_r²=sim_consts.skin_r^2,
n_snapshots=n_snapshots,

View file

@ -114,31 +114,31 @@ function init_sim_with_sim_consts(
bundle = Bundle(n_particles, 1)
save_snapshot!(bundle, 1, 0.0, particles)
particles = nothing
dir = exports_dir
if length(parent_dir) > 0
exports_dir *= "/$parent_dir"
dir *= "/$parent_dir"
end
start_datetime = now()
exports_dir *= "/$(start_datetime)_N=$(sim_consts.n_particles)_v=$(sim_consts.v₀)_#$(rand(1000:9999))"
dir *= "/$(start_datetime)_N=$(sim_consts.n_particles)_v=$(sim_consts.v₀)_#$(rand(1000:9999))"
if length(comment) > 0
exports_dir *= "_$comment"
dir *= "_$comment"
end
mkpath(exports_dir)
mkpath(dir)
task = @async write_struct_to_json(sim_consts, "$exports_dir/sim_consts")
task = @async write_struct_to_json(sim_consts, "$dir/sim_consts")
save_bundle(exports_dir, bundle, 1, 0.0)
save_bundle(dir, bundle, 1, 0.0)
runs_dir = "$exports_dir/runs"
runs_dir = "$dir/runs"
mkpath(runs_dir)
wait(task)
return exports_dir
return (dir, particles)
end
function init_sim(;
@ -155,5 +155,5 @@ function init_sim(;
n_particles, v₀, δt, packing_ratio, skin_to_interaction_r_ratio
)
return init_sim_with_sim_consts(sim_consts, exports_dir, parent_dir, comment)
return init_sim_with_sim_consts(sim_consts, exports_dir, parent_dir, comment)[1]
end

View file

@ -41,7 +41,7 @@ function update_verlet_lists!(args, cl)
end
function euler!(
args, integration_hook::F, actions::Vector{Tuple{Float64,Float64}}
args, state_hook::F, integration_hook::F, rl_params::Union{ReCoRL.Params,Nothing}
) where {F<:Function}
for id1 in 1:(args.n_particles - 1)
p1 = args.particles[id1]
@ -55,6 +55,8 @@ function euler!(
p1_c, p2.c, args.interaction_r², args.half_box_len
)
state_hook(id1, id2, r⃗₁₂, distance², rl_params)
if overlapping
c = args.c₁ / (distance²^4) * (args.c₂ / (distance²^3) - 1.0)
@ -76,7 +78,7 @@ function euler!(
restrict_coordinates!(p, args.half_box_len)
integration_hook(p, actions)
integration_hook(p, rl_params, args.δt)
p.c = p.tmp_c
end
@ -86,6 +88,13 @@ end
wait(::Nothing) = nothing
gen_run_hooks(::Nothing, args...) = false
function gen_run_hooks(rl_params::ReCoRL.Params, integration_step::Int64)
return (itegration_step == 1) ||
(integration_step % rl_params.n_steps_before_actions_update == 0)
end
function simulate(
args,
δt::Float64,
@ -108,7 +117,8 @@ function simulate(
cl = CellList(args.particles_c, args.box; parallel=false)
cl = update_verlet_lists!(args, cl)
update_actions = true
run_hooks = false
state_hook = empty_hook
start_time = now()
println("Started simulation at $start_time.")
@ -134,16 +144,18 @@ function simulate(
cl = update_verlet_lists!(args, cl)
end
update_actions = integration_step % rl_params.n_steps_before_actions_update == 0
run_hooks = gen_run_hooks(rl_params, integration_step)
if update_actions
pre_integration_hook!(rl_params)
if run_hooks
pre_integration_hook!(rl_params, args.n_particles)
state_hook = ReCoRL.state_hook
end
euler!(args, integration_hook, rl.params.actions)
euler!(args, state_hook, integration_hook, rl.params.actions)
if update_actions
post_integration_hook(rl_params)
if run_hooks
post_integration_hook(rl_params, args.n_particles, args.particles)
state_hook = empty_hook
end
end