1
0
Fork 0
mirror of https://gitlab.rlp.net/mobitar/ReCo.jl.git synced 2024-09-19 19:01:17 +00:00
ReCo.jl/src/reinforcement_learning.jl

345 lines
9.7 KiB
Julia
Raw Normal View History

2021-12-12 23:19:18 +00:00
module RL
2021-12-12 14:29:08 +00:00
2021-12-10 02:16:45 +00:00
using ReinforcementLearning
2021-12-12 17:27:56 +00:00
using Flux: InvDecay
2021-12-12 14:29:08 +00:00
using Intervals
2021-12-12 17:27:56 +00:00
using StaticArrays: SVector
2021-12-12 23:19:18 +00:00
using Random: Random
using ..ReCo
2021-12-12 14:29:08 +00:00
import Base: run
"""
    DistanceState{L}(lower::Float64, upper::Float64)

One bin of the discretised distance state space. It wraps the interval from `lower` to
`upper`, which is always closed at `upper`; the type parameter `L` (an Intervals.jl `Bound`
such as `Open` or `Closed`) selects whether `lower` itself belongs to the bin.
"""
struct DistanceState{L<:Bound}
    interval::Interval{Float64,L,Closed}

    DistanceState{L}(lower::Float64, upper::Float64) where {L<:Bound} =
        new{L}(Interval{Float64,L,Closed}(lower, upper))
end
"""
    DirectionState(lower::Float64, upper::Float64)

One bin of the discretised direction state space: wraps the half-open interval
`[lower, upper)`, i.e. closed at `lower` and open at `upper`.
"""
struct DirectionState
    interval::Interval{Float64,Closed,Open}

    DirectionState(lower::Float64, upper::Float64) =
        new(Interval{Float64,Closed,Open}(lower, upper))
end
2021-12-10 02:16:45 +00:00
2021-12-12 14:29:08 +00:00
"""
    EnvParams(min_distance, max_distance; kwargs...)

Environment description used by the per-particle environments: the discrete action space,
the discretised distance and direction state spaces, their product state space and the
current reward (initialised to `0.0`).

# Arguments
- `min_distance::Float64`, `max_distance::Float64`: range of distances that is split into
  distance bins; requires `0 < min_distance < max_distance`.

# Keywords
- `n_v_actions::Int64=2`: number of evenly spaced speeds `v` in `[0, max_v]` (must be > 1).
- `n_ω_actions::Int64=3`: number of evenly spaced angular velocities `ω` in
  `[-max_ω, max_ω]` (must be > 1).
- `max_v::Float64=20.0`, `max_ω::Float64=π / 1.5`: positive maxima of the two action ranges.
- `n_distance_states::Int64=3`: number of distance bins (must be > 0); the first bin is
  closed on both ends, the others are open below and closed above, so bins do not overlap.
- `n_direction_states::Int64=4`: number of half-open direction bins `[a, b)` covering
  `[0, 2π)` (must be > 0).

The action space holds every `(v, ω)` pair with `v` varying slowest. The state space holds
every `(distance_state, direction_state)` pair with the distance varying slowest, followed by
one extra `(nothing, nothing)` state.
"""
mutable struct EnvParams
    action_space::Vector{Tuple{Float64,Float64}}
    distance_state_space::Vector{DistanceState}
    direction_state_space::Vector{DirectionState}
    state_space::Vector{Tuple{Union{DistanceState,Nothing},Union{DirectionState,Nothing}}}
    reward::Float64

    function EnvParams(
        min_distance::Float64,
        max_distance::Float64;
        n_v_actions::Int64=2,
        n_ω_actions::Int64=3,
        max_v::Float64=20.0,
        max_ω::Float64=π / 1.5,
        n_distance_states::Int64=3,
        n_direction_states::Int64=4,
    )
        @assert min_distance > 0.0
        @assert max_distance > min_distance
        @assert n_v_actions > 1
        @assert n_ω_actions > 1
        @assert max_v > 0
        @assert max_ω > 0
        @assert n_distance_states > 0
        @assert n_direction_states > 0

        # `range(start, stop; length)` always yields exactly `length` points with exact
        # endpoints. A float `start:step:stop` may drop its last point to rounding, which
        # would leave `undef` entries in `action_space` or make the `[i + 1]` lookups below
        # throw a `BoundsError`.
        v_action_space = range(0.0, max_v; length=n_v_actions)
        ω_action_space = range(-max_ω, max_ω; length=n_ω_actions)
        action_space = Tuple{Float64,Float64}[
            (v, ω) for v in v_action_space for ω in ω_action_space
        ]

        distance_range = range(min_distance, max_distance; length=n_distance_states + 1)
        distance_state_space = Vector{DistanceState}(undef, n_distance_states)
        for i in 1:n_distance_states
            # Only the first bin includes its lower edge; every later bin starts just above
            # the previous bin's closed upper edge.
            bound = i == 1 ? Closed : Open
            distance_state_space[i] = DistanceState{bound}(
                distance_range[i], distance_range[i + 1]
            )
        end

        direction_range = range(0.0, 2π; length=n_direction_states + 1)
        direction_state_space = DirectionState[
            DirectionState(direction_range[i], direction_range[i + 1]) for
            i in 1:n_direction_states
        ]

        # Every (distance, direction) combination, then the extra `(nothing, nothing)` state.
        state_space = Tuple{Union{DistanceState,Nothing},Union{DirectionState,Nothing}}[
            (d, θ) for d in distance_state_space for θ in direction_state_space
        ]
        push!(state_space, (nothing, nothing))

        return new(
            action_space, distance_state_space, direction_state_space, state_space, 0.0
        )
    end
end
2021-12-12 14:29:08 +00:00
"""
    Env(params, particle)

RL environment of a single particle. It holds the environment parameters `params`, the
controlled `particle`, and the particle's current discrete state, which starts as
`(nothing, nothing)`.
"""
mutable struct Env <: AbstractEnv
    params::EnvParams
    particle::ReCo.Particle
    state::Tuple{Union{DistanceState,Nothing},Union{DirectionState,Nothing}}

    function Env(params::EnvParams, particle::ReCo.Particle)
        initial_state = (nothing, nothing)
        return new(params, particle, initial_state)
    end
end
2021-12-12 17:27:56 +00:00
"""
    gen_policy(n_states, n_actions)

Build a Q-based policy: a Monte-Carlo learner over a tabular Q approximator of size
`n_states` × `n_actions` (optimiser `InvDecay(1.0)`), explored with
`EpsilonGreedyExplorer(0.1)`.
"""
function gen_policy(n_states, n_actions)
    q_table = TabularQApproximator(;
        n_state=n_states, n_action=n_actions, opt=InvDecay(1.0)
    )
    learner = MonteCarloLearner(; approximator=q_table)
    explorer = EpsilonGreedyExplorer(0.1)
    return QBasedPolicy(; learner=learner, explorer=explorer)
end
"""
    Params{H}(n_particles, env_params, n_steps_before_actions_update)

RL bookkeeping for a simulation of `n_particles` particles. For each particle it holds:
- an environment built from `env_params`;
- an agent with its own tabular policy;
- an `H()` hook;
- the particle's current action;
- the nearest-neighbour data accumulated by `state_hook`.
"""
struct Params{H<:AbstractHook}
    envs::Vector{Env}
    agents::Vector{Agent}
    hooks::Vector{H}
    # Action chosen for each particle; `undef` until the agents first act.
    actions::Vector{Tuple{Float64,Float64}}
    env_params::EnvParams
    n_steps_before_actions_update::Int64
    # Smallest squared pair distance recorded per particle; `Inf` means none recorded yet.
    min_distance²::Vector{Float64}
    # Separation vector that belongs to the distance stored in `min_distance²`.
    r⃗₁₂_to_min_distance_particle::Vector{SVector{2,Float64}}

    function Params{H}(
        n_particles::Int64, env_params::EnvParams, n_steps_before_actions_update::Int64
    ) where {H<:AbstractHook}
        n_states = length(env_params.state_space)
        n_actions = length(env_params.action_space)
        policies = [gen_policy(n_states, n_actions) for _ in 1:n_particles]
        agents = [
            Agent(; policy=policy, trajectory=VectorSARTTrajectory()) for policy in policies
        ]

        return new(
            [Env(env_params, ReCo.gen_tmp_particle()) for _ in 1:n_particles],
            agents,
            [H() for _ in 1:n_particles],
            Vector{Tuple{Float64,Float64}}(undef, n_particles),
            env_params,
            n_steps_before_actions_update,
            # `state_hook` only stores a distance strictly smaller than the current entry,
            # so entries must start at `Inf`; starting at `0.0` would mean no distance is
            # ever recorded.
            fill(Inf, n_particles),
            fill(SVector(0.0, 0.0), n_particles),
        )
    end
end
# ReinforcementLearning.jl environment interface for `Env`. The state space, action space and
# reward are stored in the environment's `EnvParams`; only the current state is stored on the
# environment itself (`Env` has no `state_space` field).
# NOTE(review): tabular Q approximators and `VectorSARTTrajectory` typically work with
# integer state/action indices; confirm they accept the tuple-valued states and actions here.
RLBase.state_space(env::Env) = env.params.state_space
RLBase.state(env::Env) = env.state
RLBase.action_space(env::Env) = env.params.action_space
RLBase.reward(env::Env) = env.params.reward
2021-12-12 17:27:56 +00:00
"""
    pre_integration_hook!(rl_params, n_particles)

For each of the first `n_particles` particles:
1. ask the particle's agent for an action given its environment;
2. store that action in `rl_params.actions`;
3. run the agent's `PRE_ACT_STAGE` callback with it, then the hook's.
"""
function pre_integration_hook!(rl_params::Params, n_particles::Int64)
    foreach(1:n_particles) do id
        env = rl_params.envs[id]
        agent = rl_params.agents[id]
        chosen_action = agent(env)
        rl_params.actions[id] = chosen_action

        agent(PRE_ACT_STAGE, env, chosen_action)
        rl_params.hooks[id](PRE_ACT_STAGE, agent, env, chosen_action)
    end

    return nothing
end
"""
    state_hook(id1, id2, r⃗₁₂, distance², rl_params)

Offer the squared pair distance `distance²` to the stored minima of both particles `id1` and
`id2`, in that order. Whenever it is strictly smaller than a particle's entry in
`rl_params.min_distance²`, that entry is replaced and the matching separation vector is
stored: `r⃗₁₂` for `id1`, its negation for `id2`.
"""
function state_hook(
    id1::Int64, id2::Int64, r⃗₁₂::SVector{2,Float64}, distance²::Float64, rl_params::Params
)
    for (id, r⃗) in ((id1, r⃗₁₂), (id2, -r⃗₁₂))
        distance² < rl_params.min_distance²[id] || continue
        rl_params.min_distance²[id] = distance²
        rl_params.r⃗₁₂_to_min_distance_particle[id] = r⃗
    end

    return nothing
end
2021-12-12 23:19:18 +00:00
"""
    integration_hook(particle, rl_params, δt)

Apply the particle's current action `(v, ω)` from `rl_params.actions` over one time step
`δt`. The particle's `tmp_c` is moved by `v * δt` along its heading `(cos φ, sin φ)`, and its
heading angle `φ` is then rotated by `ω * δt`.

NOTE(review): assumes `particle.tmp_c` is the particle's 2-D position (e.g. an
`SVector{2,Float64}`). Adding a bare scalar such as `v * δt` to such a vector has no method.
Confirm against `ReCo.Particle`.
"""
function integration_hook(particle::ReCo.Particle, rl_params::Params, δt::Float64)
    v, ω = rl_params.actions[particle.id]
    # The translation uses the heading from before this step's rotation.
    si, co = sincos(particle.φ)
    particle.tmp_c += (v * δt) * SVector(co, si)
    particle.φ += ω * δt
    return nothing
end
"""
    post_integration_hook(rl_params, n_particles, particles)

Turn each particle's nearest-neighbour data into a discrete RL state, then run the
`POST_ACT_STAGE` callbacks of the particle's agent and hook.

For particle `i`:
- The square root of `rl_params.min_distance²[i]` selects the distance bin. If it lies in no
  bin, the state is `(nothing, nothing)`.
- Otherwise the direction bin is selected by the counter-clockwise angle in `[0, 2π)` from
  the heading `(cos φ, sin φ)` of `particles[i]` to the stored separation vector.
- Afterwards the stored minimum is reset to `Inf`, so the next call only sees distances
  recorded after this one.
"""
function post_integration_hook(
    rl_params::Params, n_particles::Int64, particles::Vector{ReCo.Particle}
)
    distance_states = rl_params.env_params.distance_state_space
    direction_states = rl_params.env_params.direction_state_space

    for i in 1:n_particles
        env = rl_params.envs[i]
        agent = rl_params.agents[i]

        min_distance = sqrt(rl_params.min_distance²[i])
        distance_ind = findfirst(s -> min_distance in s.interval, distance_states)

        if isnothing(distance_ind)
            env.state = (nothing, nothing)
        else
            r⃗₁₂ = rl_params.r⃗₁₂_to_min_distance_particle[i]
            si, co = sincos(particles[i].φ)
            # Signed angle from e = (co, si) to r⃗₁₂: atan(e × r⃗₁₂, e ⋅ r⃗₁₂) ∈ [-π, π].
            # acos(e ⋅ r⃗₁₂ / |r⃗₁₂|) would only reach [0, π], leaving the bins above π
            # unused, and can throw a DomainError when rounding pushes its argument
            # slightly outside [-1, 1].
            θ = atan(co * r⃗₁₂[2] - si * r⃗₁₂[1], co * r⃗₁₂[1] + si * r⃗₁₂[2])
            direction = θ < 0.0 ? θ + 2π : θ
            # The direction bins are contiguous from 0, so an angle in no bin can only be at
            # (or rounded up to) 2π, which is the same direction as 0: use the first bin.
            direction_ind = something(
                findfirst(s -> direction in s.interval, direction_states), 1
            )
            env.state = (distance_states[distance_ind], direction_states[direction_ind])
        end

        # Collect a fresh minimum for the next call.
        rl_params.min_distance²[i] = Inf

        agent(POST_ACT_STAGE, env)
        rl_params.hooks[i](POST_ACT_STAGE, agent, env)
    end

    return nothing
end
2021-12-12 14:29:08 +00:00
2021-12-12 23:19:18 +00:00
"""
    run(; n_episodes=100, episode_duration=5.0, update_actions_at=0.1, n_particles=10)

Train one tabular RL agent per particle over `n_episodes` simulations of duration
`episode_duration`, and return the per-particle `TotalRewardPerEpisode` hooks.

- `update_actions_at` is converted into a number of integration steps between action
  updates. It must be a multiple of `0.01` that is not larger than `episode_duration`.
- `n_particles` is passed to `ReCo.gen_sim_consts`, and the particle count from the returned
  constants is used afterwards.
- The global RNG is seeded with `42`, and each episode's simulation gets a seed drawn from it.

NOTE(review): because the module does `import Base: run`, this adds a zero-positional-argument
method to `Base.run` (type piracy); consider dropping that import so this is a module-local
`RL.run`.
"""
function run(;
    n_episodes::Int64=100,
    episode_duration::Float64=5.0,
    update_actions_at::Float64=0.1,
    n_particles::Int64=10,
)
    @assert n_episodes > 0
    @assert episode_duration > 0
    @assert update_actions_at in 0.01:0.01:episode_duration
    @assert n_particles > 0
    Random.seed!(42)

    sim_consts = ReCo.gen_sim_consts(n_particles, 0.0; skin_to_interaction_r_ratio=3.5)
    n_particles = sim_consts.n_particles

    env_params = EnvParams(sim_consts.particle_radius, sim_consts.skin_r)

    n_steps_before_actions_update = round(Int64, update_actions_at / sim_consts.δt)

    rl_params = Params{TotalRewardPerEpisode}(
        n_particles, env_params, n_steps_before_actions_update
    )

    for i in 1:n_particles
        env = rl_params.envs[i]
        agent = rl_params.agents[i]
        rl_params.hooks[i](PRE_EXPERIMENT_STAGE, agent, env)
        agent(PRE_EXPERIMENT_STAGE, env)
    end

    for _ in 1:n_episodes
        dir, particles = ReCo.init_sim_with_sim_consts(sim_consts; parent_dir="RL")

        # Reset the environments before the pre-episode stage (the order used by
        # ReinforcementLearning.jl's own run loop), so agents and hooks already see this
        # episode's particles and initial state.
        for i in 1:n_particles
            rl_params.envs[i].particle = particles[i]
            rl_params.envs[i].state = (nothing, nothing)
        end
        # Minima recorded for the previous episode's particles must not leak into this
        # episode's first state.
        fill!(rl_params.min_distance², Inf)
        rl_params.env_params.reward = 0.0

        for i in 1:n_particles
            env = rl_params.envs[i]
            agent = rl_params.agents[i]
            rl_params.hooks[i](PRE_EPISODE_STAGE, agent, env)
            agent(PRE_EPISODE_STAGE, env)
        end

        run_sim(
            dir; duration=episode_duration, seed=rand(1:typemax(Int64)), rl_params=rl_params
        )

        for i in 1:n_particles
            env = rl_params.envs[i]
            agent = rl_params.agents[i]
            rl_params.hooks[i](POST_EPISODE_STAGE, agent, env)
            agent(POST_EPISODE_STAGE, env)
        end
    end

    for i in 1:n_particles
        env = rl_params.envs[i]
        agent = rl_params.agents[i]
        rl_params.hooks[i](POST_EXPERIMENT_STAGE, agent, env)
    end

    return rl_params.hooks
end
2021-12-10 02:16:45 +00:00
2021-12-12 14:29:08 +00:00
end # module