1
0
Fork 0
mirror of https://gitlab.rlp.net/mobitar/ReCo.jl.git synced 2024-09-19 19:01:17 +00:00

state and action ind

This commit is contained in:
MoBit 2021-12-13 02:24:34 +01:00
parent 3036c5e65a
commit cfb38c6a19
3 changed files with 89 additions and 34 deletions

View file

@ -5,6 +5,7 @@ using Flux: InvDecay
using Intervals using Intervals
using StaticArrays: SVector using StaticArrays: SVector
using Random: Random using Random: Random
using ProgressMeter: @showprogress
using ..ReCo using ..ReCo
@ -28,17 +29,19 @@ end
mutable struct EnvParams mutable struct EnvParams
action_space::Vector{Tuple{Float64,Float64}} action_space::Vector{Tuple{Float64,Float64}}
action_space_ind::Vector{Int64}
distance_state_space::Vector{DistanceState} distance_state_space::Vector{DistanceState}
direction_state_space::Vector{DirectionState} direction_state_space::Vector{DirectionState}
state_space::Vector{Tuple{Union{DistanceState,Nothing},Union{DirectionState,Nothing}}} state_space::Vector{Tuple{Union{DistanceState,Nothing},Union{DirectionState,Nothing}}}
state_space_ind::Vector{Int64}
reward::Float64 reward::Float64
function EnvParams( function EnvParams(
min_distance::Float64, min_distance::Float64,
max_distance::Float64; max_distance::Float64;
n_v_actions::Int64=2, n_v_actions::Int64=5,
n_ω_actions::Int64=3, n_ω_actions::Int64=5,
max_v::Float64=20.0, max_v::Float64=80.0,
max_ω::Float64=π / 1.5, max_ω::Float64=π / 1.5,
n_distance_states::Int64=3, n_distance_states::Int64=3,
n_direction_states::Int64=4, n_direction_states::Int64=4,
@ -65,6 +68,8 @@ mutable struct EnvParams
end end
end end
action_space_ind = collect(1:n_actions)
distance_range = distance_range =
min_distance:((max_distance - min_distance) / n_distance_states):max_distance min_distance:((max_distance - min_distance) / n_distance_states):max_distance
@ -109,8 +114,18 @@ mutable struct EnvParams
end end
state_space[ind] = (nothing, nothing) state_space[ind] = (nothing, nothing)
state_space_ind = collect(1:n_states)
initial_reward = 0.0
return new( return new(
action_space, distance_state_space, direction_state_space, state_space, 0.0 action_space,
action_space_ind,
distance_state_space,
direction_state_space,
state_space,
state_space_ind,
initial_reward,
) )
end end
end end
@ -118,14 +133,17 @@ end
mutable struct Env <: AbstractEnv mutable struct Env <: AbstractEnv
params::EnvParams params::EnvParams
particle::ReCo.Particle particle::ReCo.Particle
state::Tuple{Union{DistanceState,Nothing},Union{DirectionState,Nothing}} state_ind::Int64
function Env(params::EnvParams, particle::ReCo.Particle) function Env(params::EnvParams, particle::ReCo.Particle)
return new(params, particle, (nothing, nothing)) # initial_state = (nothing, nothing)
initial_state_ind = length(params.state_space_ind)
return new(params, particle, initial_state_ind)
end end
end end
function gen_policy(n_states, n_actions) function gen_policy(n_states::Int64, n_actions::Int64)
return QBasedPolicy(; return QBasedPolicy(;
learner=MonteCarloLearner(; learner=MonteCarloLearner(;
approximator=TabularQApproximator(; approximator=TabularQApproximator(;
@ -145,9 +163,13 @@ struct Params{H<:AbstractHook}
n_steps_before_actions_update::Int64 n_steps_before_actions_update::Int64
min_distance²::Vector{Float64} min_distance²::Vector{Float64}
r⃗₁₂_to_min_distance_particle::Vector{SVector{2,Float64}} r⃗₁₂_to_min_distance_particle::Vector{SVector{2,Float64}}
goal_shape_ratio::Float64
function Params{H}( function Params{H}(
n_particles::Int64, env_params::EnvParams, n_steps_before_actions_update::Int64 n_particles::Int64,
env_params::EnvParams,
n_steps_before_actions_update::Int64,
goal_shape_ratio::Float64,
) where {H<:AbstractHook} ) where {H<:AbstractHook}
policies = [ policies = [
gen_policy(length(env_params.state_space), length(env_params.action_space)) for gen_policy(length(env_params.state_space), length(env_params.action_space)) for
@ -165,27 +187,32 @@ struct Params{H<:AbstractHook}
n_steps_before_actions_update, n_steps_before_actions_update,
zeros(n_particles), zeros(n_particles),
fill(SVector(0.0, 0.0), n_particles), fill(SVector(0.0, 0.0), n_particles),
goal_shape_ratio,
) )
end end
end end
RLBase.state_space(env::Env) = env.state_space RLBase.state_space(env::Env) = env.params.state_space_ind
RLBase.state(env::Env) = env.state RLBase.state(env::Env) = env.state_ind
RLBase.action_space(env::Env) = env.params.action_space RLBase.action_space(env::Env) = env.params.action_space_ind
RLBase.reward(env::Env) = env.params.reward RLBase.reward(env::Env) = env.params.reward
RLBase.is_terminated(::Env) = false
function pre_integration_hook!(rl_params::Params, n_particles::Int64) function pre_integration_hook!(rl_params::Params, n_particles::Int64)
for i in 1:n_particles for i in 1:n_particles
env = rl_params.envs[i] env = rl_params.envs[i]
agent = rl_params.agents[i] agent = rl_params.agents[i]
action = agent(env)
action_ind = agent(env)
action = rl_params.env_params.action_space[action_ind]
rl_params.actions[i] = action rl_params.actions[i] = action
agent(PRE_ACT_STAGE, env, action) agent(PRE_ACT_STAGE, env, action_ind)
rl_params.hooks[i](PRE_ACT_STAGE, agent, env, action) rl_params.hooks[i](PRE_ACT_STAGE, agent, env, action_ind)
end end
return nothing return nothing
@ -209,17 +236,29 @@ function state_hook(
return nothing return nothing
end end
function integration_hook(particle::ReCo.Particle, rl_params::Params, δt::Float64) function integration_hook(
particle::ReCo.Particle, rl_params::Params, δt::Float64, si::Float64, co::Float64
)
action = rl_params.actions[particle.id] action = rl_params.actions[particle.id]
particle.tmp_c += action[1] * δt vδt = action[1] * δt
particle.tmp_c += SVector(vδt * co, vδt * si)
particle.φ += action[2] * δt particle.φ += action[2] * δt
return nothing return nothing
end end
function get_state_ind(
state::T, states::Vector{T}
) where {T<:Tuple{Union{DistanceState,Nothing},Union{DirectionState,Nothing}}}
return findfirst(x -> x == state, states)
end
function post_integration_hook( function post_integration_hook(
rl_params::Params, n_particles::Int64, particles::Vector{ReCo.Particle} rl_params::Params,
n_particles::Int64,
particles::Vector{ReCo.Particle},
half_box_len::Float64,
) )
env_direction_state = rl_params.env_params.direction_state_space[1] env_direction_state = rl_params.env_params.direction_state_space[1]
@ -239,7 +278,8 @@ function post_integration_hook(
end end
if isnothing(env_distance_state) if isnothing(env_distance_state)
env.state = (nothing, nothing) # (nothing, nothing)
env.state_ind = length(env.params.state_space)
else else
r⃗₁₂ = rl_params.r⃗₁₂_to_min_distance_particle[i] r⃗₁₂ = rl_params.r⃗₁₂_to_min_distance_particle[i]
si, co = sincos(particles[i].φ) si, co = sincos(particles[i].φ)
@ -259,9 +299,17 @@ function post_integration_hook(
end end
end end
env.state = (env_distance_state, env_direction_state) state = (env_distance_state, env_direction_state)
env.state_ind = get_state_ind(state, env.params.state_space)
end end
env.params.reward =
1 -
(
ReCo.gyration_tensor_eigvals_ratio(particles, half_box_len) -
rl_params.goal_shape_ratio
)^2
agent(POST_ACT_STAGE, env) agent(POST_ACT_STAGE, env)
rl_params.hooks[i](POST_ACT_STAGE, agent, env) rl_params.hooks[i](POST_ACT_STAGE, agent, env)
end end
@ -270,11 +318,13 @@ function post_integration_hook(
end end
function run(; function run(;
goal_shape_ratio::Float64,
n_episodes::Int64=100, n_episodes::Int64=100,
episode_duration::Float64=5.0, episode_duration::Float64=100.0,
update_actions_at::Float64=0.1, update_actions_at::Float64=0.1,
n_particles::Int64=10, n_particles::Int64=100,
) )
@assert 0.0 <= goal_shape_ratio <= 1.0
@assert n_episodes > 0 @assert n_episodes > 0
@assert episode_duration > 0 @assert episode_duration > 0
@assert update_actions_at in 0.01:0.01:episode_duration @assert update_actions_at in 0.01:0.01:episode_duration
@ -282,7 +332,7 @@ function run(;
Random.seed!(42) Random.seed!(42)
sim_consts = ReCo.gen_sim_consts(n_particles, 0.0; skin_to_interaction_r_ratio=3.5) sim_consts = ReCo.gen_sim_consts(n_particles, 0.0; skin_to_interaction_r_ratio=4.0)
n_particles = sim_consts.n_particles n_particles = sim_consts.n_particles
env_params = EnvParams(sim_consts.particle_radius, sim_consts.skin_r) env_params = EnvParams(sim_consts.particle_radius, sim_consts.skin_r)
@ -290,7 +340,7 @@ function run(;
n_steps_before_actions_update = round(Int64, update_actions_at / sim_consts.δt) n_steps_before_actions_update = round(Int64, update_actions_at / sim_consts.δt)
rl_params = Params{TotalRewardPerEpisode}( rl_params = Params{TotalRewardPerEpisode}(
n_particles, env_params, n_steps_before_actions_update n_particles, env_params, n_steps_before_actions_update, goal_shape_ratio
) )
for i in 1:n_particles for i in 1:n_particles
@ -301,7 +351,7 @@ function run(;
agent(PRE_EXPERIMENT_STAGE, env) agent(PRE_EXPERIMENT_STAGE, env)
end end
for episode in 1:n_episodes @showprogress 0.6 for episode in 1:n_episodes
dir, particles = ReCo.init_sim_with_sim_consts(sim_consts; parent_dir="RL") dir, particles = ReCo.init_sim_with_sim_consts(sim_consts; parent_dir="RL")
for i in 1:n_particles for i in 1:n_particles
@ -314,7 +364,7 @@ function run(;
for i in 1:n_particles for i in 1:n_particles
rl_params.envs[i].particle = particles[i] rl_params.envs[i].particle = particles[i]
rl_params.envs[i].state = (nothing, nothing) rl_params.envs[i].state_ind = length(rl_params.env_params.state_space)
end end
rl_params.env_params.reward = 0.0 rl_params.env_params.reward = 0.0
@ -339,7 +389,7 @@ function run(;
rl_params.hooks[i](POST_EXPERIMENT_STAGE, agent, env) rl_params.hooks[i](POST_EXPERIMENT_STAGE, agent, env)
end end
return rl_params.hooks return rl_params
end end
end # module end # module

View file

@ -70,16 +70,16 @@ function gen_sim_consts(
max_approach_after_one_integration_step = buffer * (2 * v₀ * δt) / interaction_r max_approach_after_one_integration_step = buffer * (2 * v₀ * δt) / interaction_r
@assert skin_to_interaction_r_ratio >= 1 + max_approach_after_one_integration_step @assert skin_to_interaction_r_ratio >= 1 + max_approach_after_one_integration_step
skin_r = skin_to_interaction_r_ratio * interaction_r
n_steps_before_verlet_list_update = round( n_steps_before_verlet_list_update = round(
Int64, Int64,
(skin_to_interaction_r_ratio - 1) / max_approach_after_one_integration_step, (skin_to_interaction_r_ratio - 1) / max_approach_after_one_integration_step,
) )
else else
skin_r = 1.5 * interaction_r
n_steps_before_verlet_list_update = 100 n_steps_before_verlet_list_update = 100
end end
skin_r = skin_to_interaction_r_ratio * interaction_r
grid_n = round(Int64, ceil(sqrt(n_particles))) grid_n = round(Int64, ceil(sqrt(n_particles)))
n_particles = grid_n^2 n_particles = grid_n^2

View file

@ -41,8 +41,11 @@ function update_verlet_lists!(args, cl)
end end
function euler!( function euler!(
args, state_hook::F, integration_hook::F, rl_params::Union{RL.Params,Nothing} args,
) where {F<:Function} state_hook::Function,
integration_hook::Function,
rl_params::Union{RL.Params,Nothing},
)
for id1 in 1:(args.n_particles - 1) for id1 in 1:(args.n_particles - 1)
p1 = args.particles[id1] p1 = args.particles[id1]
p1_c = p1.c p1_c = p1.c
@ -74,11 +77,11 @@ function euler!(
args.v₀δt * si + args.c₃ * rand_normal01(), args.v₀δt * si + args.c₃ * rand_normal01(),
) )
p.φ += args.c₄ * rand_normal01()
restrict_coordinates!(p, args.half_box_len) restrict_coordinates!(p, args.half_box_len)
integration_hook(p, rl_params, args.δt) integration_hook(p, rl_params, args.δt, si, co)
p.φ += args.c₄ * rand_normal01()
p.c = p.tmp_c p.c = p.tmp_c
end end
@ -153,7 +156,9 @@ function simulate(
euler!(args, state_hook, integration_hook, rl_params) euler!(args, state_hook, integration_hook, rl_params)
if run_hooks if run_hooks
post_integration_hook(rl_params, args.n_particles, args.particles) post_integration_hook(
rl_params, args.n_particles, args.particles, args.half_box_len
)
state_hook = empty_hook state_hook = empty_hook
end end
end end