module RL

export run_rl

using Base: OneTo

using ReinforcementLearning
using Flux: InvDecay
using Intervals
using StaticArrays: SVector
using LoopVectorization: @turbo
using Random: Random
using ProgressMeter: @showprogress

using ..ReCo: ReCo, Particle, angle2, center_of_mass

const INITIAL_REWARD = 0.0
const INITIAL_STATE_IND = 1

# Divide the angle range [-π, π] into n_angle_states intervals. The first
# interval is closed on both ends; all following intervals are left-open,
# so every angle belongs to exactly one interval.
function angle_state_space(n_angle_states::Int64)
    angle_range = range(; start=-π, stop=π, length=n_angle_states + 1)

    states = Vector{Interval}(undef, n_angle_states)

    @simd for i in 1:n_angle_states
        bound = i == 1 ? Closed : Open

        states[i] = Interval{Float64,bound,Closed}(angle_range[i], angle_range[i + 1])
    end

    return states
end
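
# Quick check for angle_state_space (hedged sketch; run in a REPL with the
# package loaded, not executed as part of the module):
#
#   space = ReCo.RL.angle_state_space(2)   # intervals [-π, 0] and (0, π]
#   @assert 0.0 in space[1]      # first interval is closed at both ends
#   @assert !(0.0 in space[2])   # second interval is open on the left
#   @assert π in space[2]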

mutable struct Env <: AbstractEnv
    n_actions::Int64
    action_space::Vector{SVector{2,Float64}}
    action_ind_space::OneTo{Int64}

    distance_state_space::Vector{Interval}
    direction_angle_state_space::Vector{Interval}
    position_angle_state_space::Vector{Interval}

    n_states::Int64
    state_space::Vector{SVector{3,Interval}}
    state_ind_space::OneTo{Int64}

    state_ind::Int64

    reward::Float64
    terminated::Bool

    center_of_mass::SVector{2,Float64} # TODO: Use or remove

    function Env(;
        max_distance::Float64,
        min_distance::Float64=0.0,
        n_v_actions::Int64=2,
        n_ω_actions::Int64=3,
        max_v::Float64=40.0,
        max_ω::Float64=π / 2,
        n_distance_states::Int64=4,
        n_direction_angle_states::Int64=3,
        n_position_angle_states::Int64=4,
    )
        @assert min_distance >= 0.0
        @assert max_distance > min_distance
        @assert n_v_actions > 1
        @assert n_ω_actions > 1
        @assert max_v > 0
        @assert max_ω > 0
        @assert n_distance_states > 1
        @assert n_direction_angle_states > 1
        @assert n_position_angle_states > 1

        # Actions are (v, ω) pairs on a regular grid.
        v_action_space = range(; start=0.0, stop=max_v, length=n_v_actions)
        ω_action_space = range(; start=-max_ω, stop=max_ω, length=n_ω_actions)

        n_actions = n_v_actions * n_ω_actions

        action_space = Vector{SVector{2,Float64}}(undef, n_actions)

        ind = 1
        for v in v_action_space
            for ω in ω_action_space
                action_space[ind] = SVector(v, ω)
                ind += 1
            end
        end

        action_ind_space = OneTo(n_actions)

        # States are (distance, direction angle, position angle) interval triples.
        distance_range = range(;
            start=min_distance, stop=max_distance, length=n_distance_states + 1
        )

        distance_state_space = Vector{Interval}(undef, n_distance_states)

        @simd for i in 1:n_distance_states
            bound = i == 1 ? Closed : Open

            distance_state_space[i] = Interval{Float64,bound,Closed}(
                distance_range[i], distance_range[i + 1]
            )
        end

        direction_angle_state_space = angle_state_space(n_direction_angle_states)
        position_angle_state_space = angle_state_space(n_position_angle_states)

        n_states = n_distance_states * n_direction_angle_states * n_position_angle_states

        state_space = Vector{SVector{3,Interval}}(undef, n_states)

        ind = 1
        for distance_state in distance_state_space
            for direction_angle_state in direction_angle_state_space
                for position_angle_state in position_angle_state_space
                    state_space[ind] = SVector(
                        distance_state, direction_angle_state, position_angle_state
                    )
                    ind += 1
                end
            end
        end

        state_ind_space = OneTo(n_states)

        return new(
            n_actions,
            action_space,
            action_ind_space,
            distance_state_space,
            direction_angle_state_space,
            position_angle_state_space,
            n_states,
            state_space,
            state_ind_space,
            INITIAL_STATE_IND,
            INITIAL_REWARD,
            false,
            SVector(0.0, 0.0),
        )
    end
end
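
# Illustrative sizes for Env (not executed): with the default keyword
# arguments above,
#   n_actions = n_v_actions * n_ω_actions = 2 * 3 = 6
#   n_states  = n_distance_states * n_direction_angle_states * n_position_angle_states
#             = 4 * 3 * 4 = 48
# e.g. Env(; max_distance=sqrt(2) * 100.0) for a box with half_box_len = 100.0.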

function reset!(env::Env)
    env.state_ind = env.n_states
    env.reward = INITIAL_REWARD
    env.terminated = false

    return nothing
end

RLBase.state_space(env::Env) = env.state_ind_space

RLBase.state(env::Env) = env.state_ind

RLBase.action_space(env::Env) = env.action_ind_space

RLBase.reward(env::Env) = env.reward

RLBase.is_terminated(env::Env) = env.terminated

struct Params{H<:AbstractHook}
    env::Env
    agent::Agent
    hook::H

    old_states_ind::Vector{Int64}
    states_ind::Vector{Int64}

    actions::Vector{SVector{2,Float64}}
    actions_ind::Vector{Int64}

    n_steps_before_actions_update::Int64

    goal_shape_ratio::Float64

    n_particles::Int64
    half_box_len::Float64
    max_elliptic_distance::Float64

    function Params(
        env::Env,
        agent::Agent,
        hook::H,
        n_steps_before_actions_update::Int64,
        goal_shape_ratio::Float64,
        n_particles::Int64,
        half_box_len::Float64,
    ) where {H<:AbstractHook}
        # Largest possible center distance: half the box diagonal.
        max_elliptic_distance = sqrt(2) * half_box_len

        n_states = env.n_states

        return new{H}(
            env,
            agent,
            hook,
            fill(0, n_particles),
            fill(n_states, n_particles),
            fill(SVector(0.0, 0.0), n_particles),
            fill(0, n_particles),
            n_steps_before_actions_update,
            goal_shape_ratio,
            n_particles,
            half_box_len,
            max_elliptic_distance,
        )
    end
end

# No-op hooks: this strategy needs no work before integration and no
# per-particle-pair state updates.
function pre_integration_hook(rl_params::Params)
    return nothing
end

function state_update_helper_hook(
    rl_params::Params, id1::Int64, id2::Int64, r⃗₁₂::SVector{2,Float64}
)
    return nothing
end

function find_state_ind(state::S, state_space::Vector{S}) where {S<:SVector{3,Interval}}
    return findfirst(x -> x == state, state_space)
end

# Return the interval of state_space containing value. The state spaces
# partition their full range, so a match is expected for in-range values.
function find_state_interval(value::Float64, state_space::Vector{Interval})::Interval
    for state in state_space
        if value in state
            return state
        end
    end
end
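
# Example for find_state_interval (illustrative): with min_distance = 0.0,
# max_distance = 8.0 and n_distance_states = 4, the distance intervals are
#   [0, 2], (2, 4], (4, 6], (6, 8]
# so find_state_interval(3.0, distance_state_space) returns (2, 4].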

function state_update_hook(rl_params::Params, particles::Vector{Particle})
    @turbo for id in 1:(rl_params.n_particles)
        rl_params.old_states_ind[id] = rl_params.states_ind[id]
    end

    env = rl_params.env

    for id in 1:(rl_params.n_particles)
        particle = particles[id]

        # Distance of the particle center from the origin
        distance = sqrt(particle.c[1]^2 + particle.c[2]^2)

        distance_state = find_state_interval(distance, env.distance_state_space)

        si, co = sincos(particle.φ)

        # Angle between the particle orientation and the direction towards the origin
        direction_angle = angle2(SVector(co, si), -particle.c)
        position_angle = atan(particle.c[2], particle.c[1])

        direction_angle_state = find_state_interval(
            direction_angle, env.direction_angle_state_space
        )
        position_angle_state = find_state_interval(
            position_angle, env.position_angle_state_space
        )

        state = SVector{3,Interval}(
            distance_state, direction_angle_state, position_angle_state
        )
        state_ind = find_state_ind(state, env.state_space)

        rl_params.states_ind[id] = state_ind
    end

    env.center_of_mass = center_of_mass(particles, rl_params.half_box_len)

    return nothing
end
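
# Example state computation in state_update_hook (sketch): a particle at
# c = (3.0, 4.0) has distance = 5.0 and position_angle = atan(4.0, 3.0) ≈ 0.927;
# each value is mapped to its interval, and the interval triple to a single
# state index via find_state_ind.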

function get_env_agent_hook(rl_params::Params)
    return (rl_params.env, rl_params.agent, rl_params.hook)
end

function update_reward!(env::Env, rl_params::Params, particle::Particle)
    # Negative squared center distance, normalized so that the reward summed
    # over all particles lies in [-1, 0].
    env.reward =
        -(particle.c[1]^2 + particle.c[2]^2) /
        (rl_params.max_elliptic_distance^2 * rl_params.n_particles)

    return nothing
end
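
# Worked example for update_reward! (sketch): for half_box_len = 100.0 one has
# max_elliptic_distance = sqrt(2) * 100.0, so with n_particles = 100 a particle
# in a box corner at c = (100.0, 100.0) yields
#   reward = -(100.0^2 + 100.0^2) / (2 * 100.0^2 * 100) = -0.01
# while a particle at the origin yields the maximal reward 0.0.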

function update_table_and_actions_hook(
    rl_params::Params, particle::Particle, first_integration_step::Bool
)
    env, agent, hook = get_env_agent_hook(rl_params)

    id = particle.id

    if !first_integration_step
        # Old state
        env.state_ind = rl_params.old_states_ind[id]

        action_ind = rl_params.actions_ind[id]

        # Pre act
        agent(PRE_ACT_STAGE, env, action_ind)
        hook(PRE_ACT_STAGE, agent, env, action_ind)

        # Update to current state
        env.state_ind = rl_params.states_ind[id]

        # Update reward
        update_reward!(env, rl_params, particle)

        # Post act
        agent(POST_ACT_STAGE, env)
        hook(POST_ACT_STAGE, agent, env)
    end

    # Update action
    action_ind = agent(env)
    action = env.action_space[action_ind]

    rl_params.actions[id] = action
    rl_params.actions_ind[id] = action_ind

    return nothing
end

act_hook(::Nothing, args...) = nothing

function act_hook(
    rl_params::Params, particle::Particle, δt::Float64, si::Float64, co::Float64
)
    # Apply the particle's current action as an Euler step: a translation of
    # v * δt along the orientation (co, si) and a rotation by ω * δt.
    action = rl_params.actions[particle.id]

    vδt = action[1] * δt

    particle.tmp_c += SVector(vδt * co, vδt * si)
    particle.φ += action[2] * δt

    return nothing
end
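
# Example for act_hook (sketch): with action = SVector(40.0, π / 2) and
# δt = 0.1, a particle pointing along the x-axis (co = 1.0, si = 0.0) is
# shifted by (4.0, 0.0) and rotated by π / 20 in this step.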

function gen_agent(n_states::Int64, n_actions::Int64, ϵ_stable::Float64)
    # TODO: Optimize warmup and decay
    warmup_steps = 200_000
    decay_steps = 1_000_000

    policy = QBasedPolicy(;
        learner=MonteCarloLearner(;
            approximator=TabularQApproximator(;
                n_state=n_states, n_action=n_actions, opt=InvDecay(1.0)
            ),
        ),
        explorer=EpsilonGreedyExplorer(;
            kind=:linear,
            ϵ_init=1.0,
            ϵ_stable=ϵ_stable,
            warmup_steps=warmup_steps,
            decay_steps=decay_steps,
        ),
    )

    return Agent(; policy=policy, trajectory=VectorSARTTrajectory())
end
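
# Exploration schedule (sketch of EpsilonGreedyExplorer with kind = :linear):
# ϵ stays at ϵ_init = 1.0 for the first warmup_steps = 200_000 steps, then
# decays linearly to ϵ_stable over decay_steps = 1_000_000 steps, i.e. the
# agent is (near-)greedy after 1_200_000 steps in total.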

function run_rl(;
    goal_shape_ratio::Float64,
    n_episodes::Int64=200,
    episode_duration::Float64=50.0,
    update_actions_at::Float64=0.1,
    n_particles::Int64=100,
    seed::Int64=42,
    ϵ_stable::Float64=0.0001,
    parent_dir::String="",
)
    @assert 0.0 <= goal_shape_ratio <= 1.0
    @assert n_episodes > 0
    @assert episode_duration > 0
    @assert update_actions_at in 0.001:0.001:episode_duration
    @assert n_particles > 0
    @assert 0.0 < ϵ_stable < 1.0

    # Setup
    Random.seed!(seed)

    sim_consts = ReCo.gen_sim_consts(
        n_particles, 0.0; skin_to_interaction_r_ratio=1.5, packing_ratio=0.22
    )
    n_particles = sim_consts.n_particles

    env = Env(; max_distance=sqrt(2) * sim_consts.half_box_len)

    agent = gen_agent(env.n_states, env.n_actions, ϵ_stable)

    n_steps_before_actions_update = round(Int64, update_actions_at / sim_consts.δt)

    hook = TotalRewardPerEpisode()

    rl_params = Params(
        env,
        agent,
        hook,
        n_steps_before_actions_update,
        goal_shape_ratio,
        n_particles,
        sim_consts.half_box_len,
    )

    parent_dir = "RL" * parent_dir

    # Pre experiment
    hook(PRE_EXPERIMENT_STAGE, agent, env)
    agent(PRE_EXPERIMENT_STAGE, env)

    @showprogress 0.6 for episode in 1:n_episodes
        dir = ReCo.init_sim_with_sim_consts(sim_consts; parent_dir=parent_dir)

        # Reset
        reset!(env)

        # Pre episode
        hook(PRE_EPISODE_STAGE, agent, env)
        agent(PRE_EPISODE_STAGE, env)

        # Episode
        ReCo.run_sim(
            dir; duration=episode_duration, seed=rand(1:typemax(Int64)), rl_params=rl_params
        )

        env.terminated = true

        # Post episode
        hook(POST_EPISODE_STAGE, agent, env)
        agent(POST_EPISODE_STAGE, env)

        # TODO: Replace with live plot
        display(hook.rewards)
        display(agent.policy.explorer.step)
    end

    # Post experiment
    hook(POST_EXPERIMENT_STAGE, agent, env)

    return rl_params
end
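
# Example usage (sketch, run from a session with ReCo loaded):
#
#   using ReCo
#   rl_params = ReCo.RL.run_rl(; goal_shape_ratio=0.5, n_episodes=10)
#
# runs 10 episodes and returns the Params struct holding the trained agent
# and the per-episode reward hook.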
end # module