mirror of
https://gitlab.rlp.net/mobitar/ReCo.jl.git
synced 2024-12-21 00:51:21 +00:00
Added compass for circle
This commit is contained in:
parent
9353621edb
commit
6f7d7fef37
1 changed files with 112 additions and 99 deletions
211
src/RL.jl
211
src/RL.jl
|
@ -2,6 +2,8 @@ module RL
|
|||
|
||||
export run_rl
|
||||
|
||||
using Base: OneTo
|
||||
|
||||
using ReinforcementLearning
|
||||
using Flux: InvDecay
|
||||
using Intervals
|
||||
|
@ -13,34 +15,57 @@ using ProgressMeter: @showprogress
|
|||
using ..ReCo: ReCo, Particle, angle2, center_of_mass
|
||||
|
||||
const INITIAL_REWARD = 0.0
|
||||
const INITIAL_STATE_IND = 1
|
||||
|
||||
function angle_state_space(n_angle_states::Int64)
|
||||
angle_range = range(; start=-π, stop=π, length=n_angle_states + 1)
|
||||
|
||||
angle_state_space = Vector{Interval}(undef, n_angle_states)
|
||||
|
||||
@simd for i in 1:n_angle_states
|
||||
if i == 1
|
||||
bound = Closed
|
||||
else
|
||||
bound = Open
|
||||
end
|
||||
|
||||
angle_state_space[i] = Interval{Float64,bound,Closed}(
|
||||
angle_range[i], angle_range[i + 1]
|
||||
)
|
||||
end
|
||||
|
||||
return angle_state_space
|
||||
end
|
||||
|
||||
mutable struct Env <: AbstractEnv
|
||||
n_actions::Int64
|
||||
action_space::Vector{SVector{2,Float64}}
|
||||
action_ind_space::Vector{Int64}
|
||||
action_ind_space::OneTo{Int64}
|
||||
|
||||
distance_state_space::Vector{Interval}
|
||||
angle_state_space::Vector{Interval}
|
||||
direction_angle_state_space::Vector{Interval}
|
||||
position_angle_state_space::Vector{Interval}
|
||||
|
||||
n_states::Int64
|
||||
state_space::Vector{SVector{2,Interval}}
|
||||
state_ind_space::Vector{Int64}
|
||||
state_space::Vector{SVector{3,Interval}}
|
||||
state_ind_space::OneTo{Int64}
|
||||
state_ind::Int64
|
||||
|
||||
reward::Float64
|
||||
terminated::Bool
|
||||
|
||||
center_of_mass::SVector{2,Float64}
|
||||
center_of_mass::SVector{2,Float64} # TODO: Use or remove
|
||||
|
||||
function Env(
|
||||
max_distance::Float64;
|
||||
function Env(;
|
||||
max_distance::Float64,
|
||||
min_distance::Float64=0.0,
|
||||
n_v_actions::Int64=3,
|
||||
n_v_actions::Int64=2,
|
||||
n_ω_actions::Int64=3,
|
||||
max_v::Float64=40.0,
|
||||
max_ω::Float64=π / 2,
|
||||
n_distance_states::Int64=3,
|
||||
n_angle_states::Int64=4,
|
||||
n_distance_states::Int64=4,
|
||||
n_direction_angle_states::Int64=3,
|
||||
n_position_angle_states::Int64=4,
|
||||
)
|
||||
@assert min_distance >= 0.0
|
||||
@assert max_distance > min_distance
|
||||
|
@ -48,9 +73,12 @@ mutable struct Env <: AbstractEnv
|
|||
@assert n_ω_actions > 1
|
||||
@assert max_v > 0
|
||||
@assert max_ω > 0
|
||||
@assert n_distance_states > 1
|
||||
@assert n_direction_angle_states > 1
|
||||
@assert n_position_angle_states > 1
|
||||
|
||||
v_action_space = 0.0:(max_v / (n_v_actions - 1)):max_v
|
||||
ω_action_space = (-max_ω):(2 * max_ω / (n_ω_actions - 1)):max_ω
|
||||
v_action_space = range(; start=0.0, stop=max_v, length=n_v_actions)
|
||||
ω_action_space = range(; start=-max_ω, stop=max_ω, length=n_ω_actions)
|
||||
|
||||
n_actions = n_v_actions * n_ω_actions
|
||||
|
||||
|
@ -64,10 +92,11 @@ mutable struct Env <: AbstractEnv
|
|||
end
|
||||
end
|
||||
|
||||
action_ind_space = collect(1:n_actions)
|
||||
action_ind_space = OneTo(n_actions)
|
||||
|
||||
distance_range =
|
||||
min_distance:((max_distance - min_distance) / n_distance_states):max_distance
|
||||
distance_range = range(;
|
||||
start=min_distance, stop=max_distance, length=n_distance_states + 1
|
||||
)
|
||||
|
||||
distance_state_space = Vector{Interval}(undef, n_distance_states)
|
||||
|
||||
|
@ -83,50 +112,38 @@ mutable struct Env <: AbstractEnv
|
|||
)
|
||||
end
|
||||
|
||||
angle_range = (-π):(2 * π / n_angle_states):π
|
||||
direction_angle_state_space = angle_state_space(n_direction_angle_states)
|
||||
position_angle_state_space = angle_state_space(n_position_angle_states)
|
||||
|
||||
angle_state_space = Vector{Interval}(undef, n_angle_states)
|
||||
n_states = n_distance_states * n_direction_angle_states * n_position_angle_states
|
||||
|
||||
@simd for i in 1:n_angle_states
|
||||
if i == 1
|
||||
bound = Closed
|
||||
else
|
||||
bound = Open
|
||||
end
|
||||
|
||||
angle_state_space[i] = Interval{Float64,bound,Closed}(
|
||||
angle_range[i], angle_range[i + 1]
|
||||
)
|
||||
end
|
||||
|
||||
n_states = n_distance_states * n_angle_states + 1
|
||||
|
||||
state_space = Vector{SVector{2,Interval}}(undef, n_states - 1)
|
||||
state_space = Vector{SVector{3,Interval}}(undef, n_states)
|
||||
|
||||
ind = 1
|
||||
for distance_state in distance_state_space
|
||||
for angle_state in angle_state_space
|
||||
state_space[ind] = SVector(distance_state, angle_state)
|
||||
ind += 1
|
||||
for direction_angle_state in direction_angle_state_space
|
||||
for position_angle_state in position_angle_state_space
|
||||
state_space[ind] = SVector(
|
||||
distance_state, direction_angle_state, position_angle_state
|
||||
)
|
||||
ind += 1
|
||||
end
|
||||
end
|
||||
end
|
||||
# Last state is SVector(nothing, nothing)
|
||||
|
||||
state_ind_space = collect(1:n_states)
|
||||
|
||||
# initial_state = SVector(nothing, nothing)
|
||||
initial_state_ind = n_states
|
||||
state_ind_space = OneTo(n_states)
|
||||
|
||||
return new(
|
||||
n_actions,
|
||||
action_space,
|
||||
action_ind_space,
|
||||
distance_state_space,
|
||||
angle_state_space,
|
||||
direction_angle_state_space,
|
||||
position_angle_state_space,
|
||||
n_states,
|
||||
state_space,
|
||||
state_ind_space,
|
||||
initial_state_ind,
|
||||
INITIAL_STATE_IND,
|
||||
INITIAL_REWARD,
|
||||
false,
|
||||
SVector(0.0, 0.0),
|
||||
|
@ -171,9 +188,6 @@ struct Params{H<:AbstractHook}
|
|||
half_box_len::Float64
|
||||
max_elliptic_distance::Float64
|
||||
|
||||
local_centers_of_mass::Vector{SVector{2,Float64}}
|
||||
updated_local_center_of_mass::Vector{Bool}
|
||||
|
||||
function Params(
|
||||
env::Env,
|
||||
agent::Agent,
|
||||
|
@ -200,37 +214,32 @@ struct Params{H<:AbstractHook}
|
|||
n_particles,
|
||||
half_box_len,
|
||||
max_elliptic_distance,
|
||||
fill(SVector(0.0, 0.0), n_particles),
|
||||
falses(n_particles),
|
||||
)
|
||||
end
|
||||
end
|
||||
|
||||
function pre_integration_hook(rl_params::Params)
|
||||
@simd for id in 1:(rl_params.n_particles)
|
||||
rl_params.local_centers_of_mass[id] = SVector(0.0, 0.0)
|
||||
rl_params.updated_local_center_of_mass[id] = false
|
||||
end
|
||||
|
||||
return nothing
|
||||
end
|
||||
|
||||
function state_update_helper_hook(
|
||||
rl_params::Params, id1::Int64, id2::Int64, r⃗₁₂::SVector{2,Float64}
|
||||
)
|
||||
rl_params.local_centers_of_mass[id1] += r⃗₁₂
|
||||
rl_params.local_centers_of_mass[id2] -= r⃗₁₂
|
||||
|
||||
rl_params.updated_local_center_of_mass[id1] = true
|
||||
rl_params.updated_local_center_of_mass[id2] = true
|
||||
|
||||
return nothing
|
||||
end
|
||||
|
||||
function get_state_ind(state::S, state_space::Vector{S}) where {S<:SVector{2,Interval}}
|
||||
function find_state_ind(state::S, state_space::Vector{S}) where {S<:SVector{3,Interval}}
|
||||
return findfirst(x -> x == state, state_space)
|
||||
end
|
||||
|
||||
function find_state_interval(value::Float64, state_space::Vector{Interval})::Interval
|
||||
for state in state_space
|
||||
if value in state
|
||||
return state
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
function state_update_hook(rl_params::Params, particles::Vector{Particle})
|
||||
@turbo for id in 1:(rl_params.n_particles)
|
||||
rl_params.old_states_ind[id] = rl_params.states_ind[id]
|
||||
|
@ -238,39 +247,29 @@ function state_update_hook(rl_params::Params, particles::Vector{Particle})
|
|||
|
||||
env = rl_params.env
|
||||
|
||||
env_distance_state = env.distance_state_space[1]
|
||||
env_angle_state = env.angle_state_space[1]
|
||||
state_ind = 0
|
||||
|
||||
for id in 1:(rl_params.n_particles)
|
||||
if !rl_params.updated_local_center_of_mass[id]
|
||||
state_ind = env.n_states
|
||||
else
|
||||
local_center_of_mass = rl_params.local_centers_of_mass[id]
|
||||
particle = particles[id]
|
||||
|
||||
distance = sqrt(local_center_of_mass[1]^2 + local_center_of_mass[2]^2)
|
||||
distance = sqrt(particle.c[1]^2 + particle.c[2]^2)
|
||||
|
||||
for distance_state in env.distance_state_space
|
||||
if distance in distance_state
|
||||
env_distance_state = distance_state
|
||||
break
|
||||
end
|
||||
end
|
||||
distance_state = find_state_interval(distance, env.distance_state_space)
|
||||
|
||||
si, co = sincos(particles[id].φ)
|
||||
si, co = sincos(particles[id].φ)
|
||||
|
||||
angle = angle2(SVector(co, si), local_center_of_mass)
|
||||
direction_angle = angle2(SVector(co, si), -particle.c)
|
||||
position_angle = atan(particle.c[2], particle.c[1])
|
||||
|
||||
for angle_state in env.angle_state_space
|
||||
if angle in angle_state
|
||||
env_angle_state = angle_state
|
||||
break
|
||||
end
|
||||
end
|
||||
direction_angle_state = find_state_interval(
|
||||
direction_angle, env.direction_angle_state_space
|
||||
)
|
||||
position_angle_state = find_state_interval(
|
||||
position_angle, env.position_angle_state_space
|
||||
)
|
||||
|
||||
state = SVector{2,Interval}(env_distance_state, env_angle_state)
|
||||
state_ind = get_state_ind(state, env.state_space)
|
||||
end
|
||||
state = SVector{3,Interval}(
|
||||
distance_state, direction_angle_state, position_angle_state
|
||||
)
|
||||
state_ind = find_state_ind(state, env.state_space)
|
||||
|
||||
rl_params.states_ind[id] = state_ind
|
||||
end
|
||||
|
@ -284,6 +283,14 @@ function get_env_agent_hook(rl_params::Params)
|
|||
return (rl_params.env, rl_params.agent, rl_params.hook)
|
||||
end
|
||||
|
||||
function update_reward!(env::Env, rl_params::Params, particle::Particle)
|
||||
env.reward =
|
||||
-(particle.c[1]^2 + particle.c[2]^2) /
|
||||
(rl_params.max_elliptic_distance^2 * rl_params.n_particles)
|
||||
|
||||
return nothing
|
||||
end
|
||||
|
||||
function update_table_and_actions_hook(
|
||||
rl_params::Params, particle::Particle, first_integration_step::Bool
|
||||
)
|
||||
|
@ -305,13 +312,7 @@ function update_table_and_actions_hook(
|
|||
env.state_ind = rl_params.states_ind[id]
|
||||
|
||||
# Update reward
|
||||
vec_to_center_of_mass = ReCo.minimum_image(
|
||||
particle.c - env.center_of_mass, rl_params.half_box_len
|
||||
)
|
||||
|
||||
env.reward =
|
||||
-(vec_to_center_of_mass[1]^2 + vec_to_center_of_mass[2]^2) /
|
||||
rl_params.max_elliptic_distance / rl_params.n_particles
|
||||
update_reward!(env, rl_params, particle)
|
||||
|
||||
# Post act
|
||||
agent(POST_ACT_STAGE, env)
|
||||
|
@ -343,14 +344,24 @@ function act_hook(
|
|||
return nothing
|
||||
end
|
||||
|
||||
function gen_agent(n_states::Int64, n_actions::Int64, ϵ::Float64)
|
||||
function gen_agent(n_states::Int64, n_actions::Int64, ϵ_stable::Float64)
|
||||
# TODO: Optimize warmup and decay
|
||||
warmup_steps = 200_000
|
||||
decay_steps = 1_000_000
|
||||
|
||||
policy = QBasedPolicy(;
|
||||
learner=MonteCarloLearner(;
|
||||
approximator=TabularQApproximator(;
|
||||
n_state=n_states, n_action=n_actions, opt=InvDecay(1.0)
|
||||
),
|
||||
),
|
||||
explorer=EpsilonGreedyExplorer(ϵ),
|
||||
explorer=EpsilonGreedyExplorer(;
|
||||
kind=:linear,
|
||||
ϵ_init=1.0,
|
||||
ϵ_stable=ϵ_stable,
|
||||
warmup_steps=warmup_steps,
|
||||
decay_steps=decay_steps,
|
||||
),
|
||||
)
|
||||
|
||||
return Agent(; policy=policy, trajectory=VectorSARTTrajectory())
|
||||
|
@ -363,7 +374,7 @@ function run_rl(;
|
|||
update_actions_at::Float64=0.1,
|
||||
n_particles::Int64=100,
|
||||
seed::Int64=42,
|
||||
ϵ::Float64=0.01,
|
||||
ϵ_stable::Float64=0.0001,
|
||||
parent_dir::String="",
|
||||
)
|
||||
@assert 0.0 <= goal_shape_ratio <= 1.0
|
||||
|
@ -371,19 +382,19 @@ function run_rl(;
|
|||
@assert episode_duration > 0
|
||||
@assert update_actions_at in 0.001:0.001:episode_duration
|
||||
@assert n_particles > 0
|
||||
@assert 0.0 < ϵ < 1.0
|
||||
@assert 0.0 < ϵ_stable < 1.0
|
||||
|
||||
# Setup
|
||||
Random.seed!(seed)
|
||||
|
||||
sim_consts = ReCo.gen_sim_consts(
|
||||
n_particles, 0.0; skin_to_interaction_r_ratio=1.8, packing_ratio=0.15
|
||||
n_particles, 0.0; skin_to_interaction_r_ratio=1.5, packing_ratio=0.22
|
||||
)
|
||||
n_particles = sim_consts.n_particles
|
||||
|
||||
env = Env(sim_consts.skin_r)
|
||||
env = Env(; max_distance=sqrt(2) * sim_consts.half_box_len)
|
||||
|
||||
agent = gen_agent(env.n_states, env.n_actions, ϵ)
|
||||
agent = gen_agent(env.n_states, env.n_actions, ϵ_stable)
|
||||
|
||||
n_steps_before_actions_update = round(Int64, update_actions_at / sim_consts.δt)
|
||||
|
||||
|
@ -426,7 +437,9 @@ function run_rl(;
|
|||
hook(POST_EPISODE_STAGE, agent, env)
|
||||
agent(POST_EPISODE_STAGE, env)
|
||||
|
||||
# TODO: Replace with live plot
|
||||
display(hook.rewards)
|
||||
display(agent.policy.explorer.step)
|
||||
end
|
||||
|
||||
# Post experiment
|
||||
|
|
Loading…
Reference in a new issue