mirror of
https://gitlab.rlp.net/mobitar/ReCo.jl.git
synced 2024-11-08 22:21:08 +00:00
Merge branch 'compass_to_center_of_mass'
This commit is contained in:
commit
275b69c928
3 changed files with 232 additions and 133 deletions
|
@ -26,13 +26,16 @@ function animate_bundle!(args, sim_consts)
|
||||||
color = get(args.color_scheme, rem2pi(bundle_φ[i, frame], RoundDown) / π2)
|
color = get(args.color_scheme, rem2pi(bundle_φ[i, frame], RoundDown) / π2)
|
||||||
args.colors[][i] = RGBAf(color)
|
args.colors[][i] = RGBAf(color)
|
||||||
|
|
||||||
if args.debug
|
if args.show_interaction_circle
|
||||||
args.interaction_circles[][i] = Circle(
|
args.interaction_circles[][i] = Circle(
|
||||||
Point2(c[1], c[2]), sim_consts.interaction_r
|
Point2(c[1], c[2]), sim_consts.interaction_r
|
||||||
)
|
)
|
||||||
|
args.interaction_colors[][i] = RGBAf(color, 0.08)
|
||||||
|
end
|
||||||
|
|
||||||
|
if args.show_skin_circle
|
||||||
args.skin_circles[][i] = Circle(Point2(c[1], c[2]), sim_consts.skin_r)
|
args.skin_circles[][i] = Circle(Point2(c[1], c[2]), sim_consts.skin_r)
|
||||||
|
|
||||||
args.interaction_colors[][i] = RGBAf(color, 0.08)
|
|
||||||
args.skin_colors[][i] = RGBAf(color, 0.04)
|
args.skin_colors[][i] = RGBAf(color, 0.04)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
@ -47,21 +50,24 @@ function animate_bundle!(args, sim_consts)
|
||||||
)
|
)
|
||||||
end
|
end
|
||||||
|
|
||||||
if args.n_bundle == 1
|
if args.n_bundle == 1 # First and only frame of first bundle
|
||||||
poly!(args.ax, args.circles; color=args.colors)
|
poly!(args.ax, args.circles; color=args.colors)
|
||||||
|
|
||||||
if args.show_center_of_mass
|
if args.show_center_of_mass
|
||||||
poly!(args.ax, args.center_of_mass_circle; color=RGBAf(1, 1, 1, 0.5))
|
poly!(args.ax, args.center_of_mass_circle; color=RGBAf(1, 1, 1, 0.5))
|
||||||
end
|
end
|
||||||
|
|
||||||
if args.debug
|
if args.show_interaction_circle
|
||||||
poly!(args.ax, args.interaction_circles; color=args.interaction_colors)
|
poly!(args.ax, args.interaction_circles; color=args.interaction_colors)
|
||||||
|
end
|
||||||
|
|
||||||
|
if args.show_skin_circle
|
||||||
poly!(args.ax, args.skin_circles; color=args.skin_colors)
|
poly!(args.ax, args.skin_circles; color=args.skin_colors)
|
||||||
end
|
end
|
||||||
|
|
||||||
println("Recording started!")
|
println("Recording started!")
|
||||||
else
|
else
|
||||||
if args.debug && frame > 1
|
if args.show_frame_diff && frame > 1
|
||||||
@simd for i in 1:(sim_consts.n_particles)
|
@simd for i in 1:(sim_consts.n_particles)
|
||||||
first_ind = 2 * i - 1
|
first_ind = 2 * i - 1
|
||||||
second_ind = 2 * i
|
second_ind = 2 * i
|
||||||
|
@ -79,6 +85,9 @@ function animate_bundle!(args, sim_consts)
|
||||||
args.ax, args.segments_x, args.segments_y; color=args.colors
|
args.ax, args.segments_x, args.segments_y; color=args.colors
|
||||||
)
|
)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
notify(args.segments_x)
|
||||||
|
notify(args.segments_y)
|
||||||
end
|
end
|
||||||
|
|
||||||
notify(args.circles)
|
notify(args.circles)
|
||||||
|
@ -88,15 +97,14 @@ function animate_bundle!(args, sim_consts)
|
||||||
notify(args.center_of_mass_circle)
|
notify(args.center_of_mass_circle)
|
||||||
end
|
end
|
||||||
|
|
||||||
if args.debug && frame > 1
|
if args.show_interaction_circle
|
||||||
notify(args.interaction_circles)
|
notify(args.interaction_circles)
|
||||||
notify(args.interaction_colors)
|
notify(args.interaction_colors)
|
||||||
|
end
|
||||||
|
|
||||||
|
if args.show_skin_circle
|
||||||
notify(args.skin_circles)
|
notify(args.skin_circles)
|
||||||
notify(args.skin_colors)
|
notify(args.skin_colors)
|
||||||
|
|
||||||
notify(args.segments_x)
|
|
||||||
notify(args.segments_y)
|
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -108,8 +116,32 @@ function animate_bundle!(args, sim_consts)
|
||||||
return nothing
|
return nothing
|
||||||
end
|
end
|
||||||
|
|
||||||
|
function sort_bundle_paths(bundle_paths::Vector{String})
|
||||||
|
n_bundles = length(bundle_paths)
|
||||||
|
|
||||||
|
bundle_nums = Vector{Int64}(undef, n_bundles)
|
||||||
|
|
||||||
|
extension_length = 5 # == length(".jld2")
|
||||||
|
|
||||||
|
for i in 1:n_bundles
|
||||||
|
bundle_path = bundle_paths[i]
|
||||||
|
bundle_num_string = bundle_path[(findfirst("bundle_", bundle_path).stop + 1):(end - extension_length)]
|
||||||
|
bundle_nums[i] = parse(Int64, bundle_num_string)
|
||||||
|
end
|
||||||
|
|
||||||
|
sort_perm = sortperm(bundle_nums)
|
||||||
|
|
||||||
|
return bundle_paths[sort_perm]
|
||||||
|
end
|
||||||
|
|
||||||
function animate_with_sim_consts(
|
function animate_with_sim_consts(
|
||||||
dir::String, sim_consts, framerate::Int64, show_center_of_mass::Bool, debug::Bool
|
dir::String,
|
||||||
|
sim_consts,
|
||||||
|
framerate::Int64,
|
||||||
|
show_center_of_mass::Bool,
|
||||||
|
show_interaction_circle::Bool,
|
||||||
|
show_skin_circle::Bool,
|
||||||
|
show_frame_diff::Bool,
|
||||||
)
|
)
|
||||||
set_theme!(theme_black())
|
set_theme!(theme_black())
|
||||||
|
|
||||||
|
@ -153,26 +185,24 @@ function animate_with_sim_consts(
|
||||||
)
|
)
|
||||||
end
|
end
|
||||||
|
|
||||||
if debug
|
if show_interaction_circle
|
||||||
segments_x = Observable(zeros(2 * n_particles))
|
|
||||||
segments_y = Observable(zeros(2 * n_particles))
|
|
||||||
|
|
||||||
interaction_circles = Observable(Vector{Circle}(undef, n_particles))
|
interaction_circles = Observable(Vector{Circle}(undef, n_particles))
|
||||||
skin_circles = Observable(Vector{Circle}(undef, n_particles))
|
skin_circles = Observable(Vector{Circle}(undef, n_particles))
|
||||||
|
end
|
||||||
|
|
||||||
|
if show_skin_circle
|
||||||
interaction_colors = Observable(Vector{RGBAf}(undef, n_particles))
|
interaction_colors = Observable(Vector{RGBAf}(undef, n_particles))
|
||||||
skin_colors = Observable(Vector{RGBAf}(undef, n_particles))
|
skin_colors = Observable(Vector{RGBAf}(undef, n_particles))
|
||||||
end
|
end
|
||||||
|
|
||||||
|
if show_frame_diff
|
||||||
|
segments_x = Observable(zeros(2 * n_particles))
|
||||||
|
segments_y = Observable(zeros(2 * n_particles))
|
||||||
|
end
|
||||||
|
|
||||||
bundle_paths = readdir("$dir/bundles"; join=true, sort=false)
|
bundle_paths = readdir("$dir/bundles"; join=true, sort=false)
|
||||||
|
|
||||||
sort_perm = sortperm([
|
bundle_paths = sort_bundle_paths(bundle_paths)
|
||||||
parse(Int64, s[(findfirst("bundle_", s).stop + 1):(end - length(".jld2"))]) for
|
|
||||||
s in bundle_paths
|
|
||||||
])
|
|
||||||
|
|
||||||
bundle_paths = bundle_paths[sort_perm]
|
|
||||||
sort_perm = nothing
|
|
||||||
|
|
||||||
@showprogress 1 for (n_bundle, bundle_path) in enumerate(bundle_paths)
|
@showprogress 1 for (n_bundle, bundle_path) in enumerate(bundle_paths)
|
||||||
bundle::Bundle = JLD2.load_object(bundle_path)
|
bundle::Bundle = JLD2.load_object(bundle_path)
|
||||||
|
@ -180,7 +210,9 @@ function animate_with_sim_consts(
|
||||||
args = (;
|
args = (;
|
||||||
# Input
|
# Input
|
||||||
show_center_of_mass,
|
show_center_of_mass,
|
||||||
debug,
|
show_interaction_circle,
|
||||||
|
show_skin_circle,
|
||||||
|
show_frame_diff,
|
||||||
# Intern
|
# Intern
|
||||||
io,
|
io,
|
||||||
ax,
|
ax,
|
||||||
|
@ -207,13 +239,26 @@ function animate_with_sim_consts(
|
||||||
end
|
end
|
||||||
|
|
||||||
function animate(
|
function animate(
|
||||||
dir::String; framerate::Int64=1, show_center_of_mass::Bool=false, debug::Bool=false
|
dir::String;
|
||||||
|
framerate::Int64=1,
|
||||||
|
show_center_of_mass::Bool=false,
|
||||||
|
show_interaction_circle::Bool=false,
|
||||||
|
show_skin_circle::Bool=false,
|
||||||
|
show_frame_diff::Bool=false,
|
||||||
)
|
)
|
||||||
println("Generating animation...")
|
println("Generating animation...")
|
||||||
|
|
||||||
sim_consts = JSON3.read(read("$dir/sim_consts.json", String))
|
sim_consts = JSON3.read(read("$dir/sim_consts.json", String))
|
||||||
|
|
||||||
animate_with_sim_consts(dir, sim_consts, framerate, show_center_of_mass, debug)
|
animate_with_sim_consts(
|
||||||
|
dir,
|
||||||
|
sim_consts,
|
||||||
|
framerate,
|
||||||
|
show_center_of_mass,
|
||||||
|
show_interaction_circle,
|
||||||
|
show_skin_circle,
|
||||||
|
show_frame_diff,
|
||||||
|
)
|
||||||
|
|
||||||
println("Animation done.")
|
println("Animation done.")
|
||||||
|
|
||||||
|
|
316
src/RL.jl
316
src/RL.jl
|
@ -2,6 +2,8 @@ module RL
|
||||||
|
|
||||||
export run_rl
|
export run_rl
|
||||||
|
|
||||||
|
using Base: OneTo
|
||||||
|
|
||||||
using ReinforcementLearning
|
using ReinforcementLearning
|
||||||
using Flux: InvDecay
|
using Flux: InvDecay
|
||||||
using Intervals
|
using Intervals
|
||||||
|
@ -10,80 +12,13 @@ using LoopVectorization: @turbo
|
||||||
using Random: Random
|
using Random: Random
|
||||||
using ProgressMeter: @showprogress
|
using ProgressMeter: @showprogress
|
||||||
|
|
||||||
using ..ReCo: ReCo, Particle, angle2, center_of_mass
|
using ..ReCo: ReCo, Particle, angle2, Shape
|
||||||
|
|
||||||
const INITIAL_REWARD = 0.0
|
const INITIAL_REWARD = 0.0
|
||||||
|
const INITIAL_STATE_IND = 1
|
||||||
|
|
||||||
mutable struct Env <: AbstractEnv
|
function angle_state_space(n_angle_states::Int64)
|
||||||
n_actions::Int64
|
angle_range = range(; start=-π, stop=π, length=n_angle_states + 1)
|
||||||
action_space::Vector{SVector{2,Float64}}
|
|
||||||
action_ind_space::Vector{Int64}
|
|
||||||
|
|
||||||
distance_state_space::Vector{Interval}
|
|
||||||
angle_state_space::Vector{Interval}
|
|
||||||
|
|
||||||
n_states::Int64
|
|
||||||
state_space::Vector{SVector{2,Interval}}
|
|
||||||
state_ind_space::Vector{Int64}
|
|
||||||
state_ind::Int64
|
|
||||||
|
|
||||||
reward::Float64
|
|
||||||
terminated::Bool
|
|
||||||
|
|
||||||
center_of_mass::SVector{2,Float64}
|
|
||||||
|
|
||||||
function Env(
|
|
||||||
max_distance::Float64;
|
|
||||||
min_distance::Float64=0.0,
|
|
||||||
n_v_actions::Int64=3,
|
|
||||||
n_ω_actions::Int64=3,
|
|
||||||
max_v::Float64=40.0,
|
|
||||||
max_ω::Float64=π / 2,
|
|
||||||
n_distance_states::Int64=3,
|
|
||||||
n_angle_states::Int64=4,
|
|
||||||
)
|
|
||||||
@assert min_distance >= 0.0
|
|
||||||
@assert max_distance > min_distance
|
|
||||||
@assert n_v_actions > 1
|
|
||||||
@assert n_ω_actions > 1
|
|
||||||
@assert max_v > 0
|
|
||||||
@assert max_ω > 0
|
|
||||||
|
|
||||||
v_action_space = 0.0:(max_v / (n_v_actions - 1)):max_v
|
|
||||||
ω_action_space = (-max_ω):(2 * max_ω / (n_ω_actions - 1)):max_ω
|
|
||||||
|
|
||||||
n_actions = n_v_actions * n_ω_actions
|
|
||||||
|
|
||||||
action_space = Vector{SVector{2,Float64}}(undef, n_actions)
|
|
||||||
|
|
||||||
ind = 1
|
|
||||||
for v in v_action_space
|
|
||||||
for ω in ω_action_space
|
|
||||||
action_space[ind] = SVector(v, ω)
|
|
||||||
ind += 1
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
action_ind_space = collect(1:n_actions)
|
|
||||||
|
|
||||||
distance_range =
|
|
||||||
min_distance:((max_distance - min_distance) / n_distance_states):max_distance
|
|
||||||
|
|
||||||
distance_state_space = Vector{Interval}(undef, n_distance_states)
|
|
||||||
|
|
||||||
@simd for i in 1:n_distance_states
|
|
||||||
if i == 1
|
|
||||||
bound = Closed
|
|
||||||
else
|
|
||||||
bound = Open
|
|
||||||
end
|
|
||||||
|
|
||||||
distance_state_space[i] = Interval{Float64,bound,Closed}(
|
|
||||||
distance_range[i], distance_range[i + 1]
|
|
||||||
)
|
|
||||||
end
|
|
||||||
|
|
||||||
angle_range = (-π):(2 * π / n_angle_states):π
|
|
||||||
|
|
||||||
angle_state_space = Vector{Interval}(undef, n_angle_states)
|
angle_state_space = Vector{Interval}(undef, n_angle_states)
|
||||||
|
|
||||||
|
@ -99,34 +34,118 @@ mutable struct Env <: AbstractEnv
|
||||||
)
|
)
|
||||||
end
|
end
|
||||||
|
|
||||||
n_states = n_distance_states * n_angle_states + 1
|
return angle_state_space
|
||||||
|
end
|
||||||
|
|
||||||
state_space = Vector{SVector{2,Interval}}(undef, n_states - 1)
|
mutable struct Env <: AbstractEnv
|
||||||
|
n_actions::Int64
|
||||||
|
action_space::Vector{SVector{2,Float64}}
|
||||||
|
action_ind_space::OneTo{Int64}
|
||||||
|
|
||||||
|
distance_state_space::Vector{Interval}
|
||||||
|
direction_angle_state_space::Vector{Interval}
|
||||||
|
position_angle_state_space::Vector{Interval}
|
||||||
|
|
||||||
|
n_states::Int64
|
||||||
|
state_space::Vector{SVector{3,Interval}}
|
||||||
|
state_ind_space::OneTo{Int64}
|
||||||
|
state_ind::Int64
|
||||||
|
|
||||||
|
reward::Float64
|
||||||
|
terminated::Bool
|
||||||
|
|
||||||
|
center_of_mass::SVector{2,Float64} # TODO: Use or remove
|
||||||
|
gyration_tensor_eigvec_to_smaller_eigval::SVector{2,Float64}
|
||||||
|
gyration_tensor_eigvec_to_bigger_eigval::SVector{2,Float64}
|
||||||
|
|
||||||
|
function Env(;
|
||||||
|
max_distance::Float64,
|
||||||
|
min_distance::Float64=0.0,
|
||||||
|
n_v_actions::Int64=2,
|
||||||
|
n_ω_actions::Int64=3,
|
||||||
|
max_v::Float64=40.0,
|
||||||
|
max_ω::Float64=π / 2,
|
||||||
|
n_distance_states::Int64=4,
|
||||||
|
n_direction_angle_states::Int64=3,
|
||||||
|
n_position_angle_states::Int64=8,
|
||||||
|
)
|
||||||
|
@assert min_distance >= 0.0
|
||||||
|
@assert max_distance > min_distance
|
||||||
|
@assert n_v_actions > 1
|
||||||
|
@assert n_ω_actions > 1
|
||||||
|
@assert max_v > 0
|
||||||
|
@assert max_ω > 0
|
||||||
|
@assert n_distance_states > 1
|
||||||
|
@assert n_direction_angle_states > 1
|
||||||
|
@assert n_position_angle_states > 1
|
||||||
|
|
||||||
|
v_action_space = range(; start=0.0, stop=max_v, length=n_v_actions)
|
||||||
|
ω_action_space = range(; start=-max_ω, stop=max_ω, length=n_ω_actions)
|
||||||
|
|
||||||
|
n_actions = n_v_actions * n_ω_actions
|
||||||
|
|
||||||
|
action_space = Vector{SVector{2,Float64}}(undef, n_actions)
|
||||||
|
|
||||||
ind = 1
|
ind = 1
|
||||||
for distance_state in distance_state_space
|
for v in v_action_space
|
||||||
for angle_state in angle_state_space
|
for ω in ω_action_space
|
||||||
state_space[ind] = SVector(distance_state, angle_state)
|
action_space[ind] = SVector(v, ω)
|
||||||
ind += 1
|
ind += 1
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
# Last state is SVector(nothing, nothing)
|
|
||||||
|
|
||||||
state_ind_space = collect(1:n_states)
|
action_ind_space = OneTo(n_actions)
|
||||||
|
|
||||||
# initial_state = SVector(nothing, nothing)
|
distance_range = range(;
|
||||||
initial_state_ind = n_states
|
start=min_distance, stop=max_distance, length=n_distance_states + 1
|
||||||
|
)
|
||||||
|
|
||||||
|
distance_state_space = Vector{Interval}(undef, n_distance_states)
|
||||||
|
|
||||||
|
@simd for i in 1:n_distance_states
|
||||||
|
if i == 1
|
||||||
|
bound = Closed
|
||||||
|
else
|
||||||
|
bound = Open
|
||||||
|
end
|
||||||
|
|
||||||
|
distance_state_space[i] = Interval{Float64,bound,Closed}(
|
||||||
|
distance_range[i], distance_range[i + 1]
|
||||||
|
)
|
||||||
|
end
|
||||||
|
|
||||||
|
direction_angle_state_space = angle_state_space(n_direction_angle_states)
|
||||||
|
position_angle_state_space = angle_state_space(n_position_angle_states)
|
||||||
|
|
||||||
|
n_states = n_distance_states * n_direction_angle_states * n_position_angle_states
|
||||||
|
|
||||||
|
state_space = Vector{SVector{3,Interval}}(undef, n_states)
|
||||||
|
|
||||||
|
ind = 1
|
||||||
|
for distance_state in distance_state_space
|
||||||
|
for direction_angle_state in direction_angle_state_space
|
||||||
|
for position_angle_state in position_angle_state_space
|
||||||
|
state_space[ind] = SVector(
|
||||||
|
distance_state, direction_angle_state, position_angle_state
|
||||||
|
)
|
||||||
|
ind += 1
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
state_ind_space = OneTo(n_states)
|
||||||
|
|
||||||
return new(
|
return new(
|
||||||
n_actions,
|
n_actions,
|
||||||
action_space,
|
action_space,
|
||||||
action_ind_space,
|
action_ind_space,
|
||||||
distance_state_space,
|
distance_state_space,
|
||||||
angle_state_space,
|
direction_angle_state_space,
|
||||||
|
position_angle_state_space,
|
||||||
n_states,
|
n_states,
|
||||||
state_space,
|
state_space,
|
||||||
state_ind_space,
|
state_ind_space,
|
||||||
initial_state_ind,
|
INITIAL_STATE_IND,
|
||||||
INITIAL_REWARD,
|
INITIAL_REWARD,
|
||||||
false,
|
false,
|
||||||
SVector(0.0, 0.0),
|
SVector(0.0, 0.0),
|
||||||
|
@ -165,21 +184,18 @@ struct Params{H<:AbstractHook}
|
||||||
|
|
||||||
n_steps_before_actions_update::Int64
|
n_steps_before_actions_update::Int64
|
||||||
|
|
||||||
goal_shape_ratio::Float64
|
goal_gyration_tensor_eigvals_ratio::Float64
|
||||||
|
|
||||||
n_particles::Int64
|
n_particles::Int64
|
||||||
half_box_len::Float64
|
half_box_len::Float64
|
||||||
max_elliptic_distance::Float64
|
max_elliptic_distance::Float64
|
||||||
|
|
||||||
local_centers_of_mass::Vector{SVector{2,Float64}}
|
|
||||||
updated_local_center_of_mass::Vector{Bool}
|
|
||||||
|
|
||||||
function Params(
|
function Params(
|
||||||
env::Env,
|
env::Env,
|
||||||
agent::Agent,
|
agent::Agent,
|
||||||
hook::H,
|
hook::H,
|
||||||
n_steps_before_actions_update::Int64,
|
n_steps_before_actions_update::Int64,
|
||||||
goal_shape_ratio::Float64,
|
goal_gyration_tensor_eigvals_ratio::Float64,
|
||||||
n_particles::Int64,
|
n_particles::Int64,
|
||||||
half_box_len::Float64,
|
half_box_len::Float64,
|
||||||
) where {H<:AbstractHook}
|
) where {H<:AbstractHook}
|
||||||
|
@ -196,41 +212,36 @@ struct Params{H<:AbstractHook}
|
||||||
fill(SVector(0.0, 0.0), n_particles),
|
fill(SVector(0.0, 0.0), n_particles),
|
||||||
fill(0, n_particles),
|
fill(0, n_particles),
|
||||||
n_steps_before_actions_update,
|
n_steps_before_actions_update,
|
||||||
goal_shape_ratio,
|
goal_gyration_tensor_eigvals_ratio,
|
||||||
n_particles,
|
n_particles,
|
||||||
half_box_len,
|
half_box_len,
|
||||||
max_elliptic_distance,
|
max_elliptic_distance,
|
||||||
fill(SVector(0.0, 0.0), n_particles),
|
|
||||||
falses(n_particles),
|
|
||||||
)
|
)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
function pre_integration_hook(rl_params::Params)
|
function pre_integration_hook(rl_params::Params)
|
||||||
@simd for id in 1:(rl_params.n_particles)
|
|
||||||
rl_params.local_centers_of_mass[id] = SVector(0.0, 0.0)
|
|
||||||
rl_params.updated_local_center_of_mass[id] = false
|
|
||||||
end
|
|
||||||
|
|
||||||
return nothing
|
return nothing
|
||||||
end
|
end
|
||||||
|
|
||||||
function state_update_helper_hook(
|
function state_update_helper_hook(
|
||||||
rl_params::Params, id1::Int64, id2::Int64, r⃗₁₂::SVector{2,Float64}
|
rl_params::Params, id1::Int64, id2::Int64, r⃗₁₂::SVector{2,Float64}
|
||||||
)
|
)
|
||||||
rl_params.local_centers_of_mass[id1] += r⃗₁₂
|
|
||||||
rl_params.local_centers_of_mass[id2] -= r⃗₁₂
|
|
||||||
|
|
||||||
rl_params.updated_local_center_of_mass[id1] = true
|
|
||||||
rl_params.updated_local_center_of_mass[id2] = true
|
|
||||||
|
|
||||||
return nothing
|
return nothing
|
||||||
end
|
end
|
||||||
|
|
||||||
function get_state_ind(state::S, state_space::Vector{S}) where {S<:SVector{2,Interval}}
|
function find_state_ind(state::S, state_space::Vector{S}) where {S<:SVector{3,Interval}}
|
||||||
return findfirst(x -> x == state, state_space)
|
return findfirst(x -> x == state, state_space)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
function find_state_interval(value::Float64, state_space::Vector{Interval})::Interval
|
||||||
|
for state in state_space
|
||||||
|
if value in state
|
||||||
|
return state
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
function state_update_hook(rl_params::Params, particles::Vector{Particle})
|
function state_update_hook(rl_params::Params, particles::Vector{Particle})
|
||||||
@turbo for id in 1:(rl_params.n_particles)
|
@turbo for id in 1:(rl_params.n_particles)
|
||||||
rl_params.old_states_ind[id] = rl_params.states_ind[id]
|
rl_params.old_states_ind[id] = rl_params.states_ind[id]
|
||||||
|
@ -238,44 +249,43 @@ function state_update_hook(rl_params::Params, particles::Vector{Particle})
|
||||||
|
|
||||||
env = rl_params.env
|
env = rl_params.env
|
||||||
|
|
||||||
env_distance_state = env.distance_state_space[1]
|
env.center_of_mass = Shape.center_of_mass(particles, rl_params.half_box_len)
|
||||||
env_angle_state = env.angle_state_space[1]
|
|
||||||
state_ind = 0
|
|
||||||
|
|
||||||
for id in 1:(rl_params.n_particles)
|
for id in 1:(rl_params.n_particles)
|
||||||
if !rl_params.updated_local_center_of_mass[id]
|
particle = particles[id]
|
||||||
state_ind = env.n_states
|
|
||||||
else
|
|
||||||
local_center_of_mass = rl_params.local_centers_of_mass[id]
|
|
||||||
|
|
||||||
distance = sqrt(local_center_of_mass[1]^2 + local_center_of_mass[2]^2)
|
vec_to_center_of_mass = ReCo.minimum_image(
|
||||||
|
env.center_of_mass - particle.c, rl_params.half_box_len
|
||||||
|
)
|
||||||
|
|
||||||
for distance_state in env.distance_state_space
|
distance = sqrt(vec_to_center_of_mass[1]^2 + vec_to_center_of_mass[2]^2)
|
||||||
if distance in distance_state
|
|
||||||
env_distance_state = distance_state
|
distance_state = find_state_interval(distance, env.distance_state_space)
|
||||||
break
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
si, co = sincos(particles[id].φ)
|
si, co = sincos(particles[id].φ)
|
||||||
|
|
||||||
angle = angle2(SVector(co, si), local_center_of_mass)
|
direction_angle = angle2(SVector(co, si), vec_to_center_of_mass)
|
||||||
|
position_angle = atan(-vec_to_center_of_mass[2], -vec_to_center_of_mass[1])
|
||||||
|
|
||||||
for angle_state in env.angle_state_space
|
direction_angle_state = find_state_interval(
|
||||||
if angle in angle_state
|
direction_angle, env.direction_angle_state_space
|
||||||
env_angle_state = angle_state
|
)
|
||||||
break
|
position_angle_state = find_state_interval(
|
||||||
end
|
position_angle, env.position_angle_state_space
|
||||||
end
|
)
|
||||||
|
|
||||||
state = SVector{2,Interval}(env_distance_state, env_angle_state)
|
state = SVector{3,Interval}(
|
||||||
state_ind = get_state_ind(state, env.state_space)
|
distance_state, direction_angle_state, position_angle_state
|
||||||
end
|
)
|
||||||
|
state_ind = find_state_ind(state, env.state_space)
|
||||||
|
|
||||||
rl_params.states_ind[id] = state_ind
|
rl_params.states_ind[id] = state_ind
|
||||||
end
|
end
|
||||||
|
|
||||||
env.center_of_mass = center_of_mass(particles, rl_params.half_box_len)
|
v1, v2 = Shape.gyration_tensor_eigvecs(particles, rl_params.half_box_len) # TODO: Reuse center_of_mass
|
||||||
|
|
||||||
|
env.gyration_tensor_eigvec_to_smaller_eigval = v1
|
||||||
|
env.gyration_tensor_eigvec_to_bigger_eigval = v2
|
||||||
|
|
||||||
return nothing
|
return nothing
|
||||||
end
|
end
|
||||||
|
@ -284,6 +294,18 @@ function get_env_agent_hook(rl_params::Params)
|
||||||
return (rl_params.env, rl_params.agent, rl_params.hook)
|
return (rl_params.env, rl_params.agent, rl_params.hook)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
function update_reward!(env::Env, rl_params::Params, particle::Particle)
|
||||||
|
env.reward =
|
||||||
|
-Shape.elliptical_distance(
|
||||||
|
particle,
|
||||||
|
env.gyration_tensor_eigvec_to_smaller_eigval,
|
||||||
|
env.gyration_tensor_eigvec_to_bigger_eigval,
|
||||||
|
rl_params.goal_gyration_tensor_eigvals_ratio,
|
||||||
|
) / (rl_params.max_elliptic_distance^2 * rl_params.n_particles)
|
||||||
|
|
||||||
|
return nothing
|
||||||
|
end
|
||||||
|
|
||||||
function update_table_and_actions_hook(
|
function update_table_and_actions_hook(
|
||||||
rl_params::Params, particle::Particle, first_integration_step::Bool
|
rl_params::Params, particle::Particle, first_integration_step::Bool
|
||||||
)
|
)
|
||||||
|
@ -305,13 +327,7 @@ function update_table_and_actions_hook(
|
||||||
env.state_ind = rl_params.states_ind[id]
|
env.state_ind = rl_params.states_ind[id]
|
||||||
|
|
||||||
# Update reward
|
# Update reward
|
||||||
vec_to_center_of_mass = ReCo.minimum_image(
|
update_reward!(env, rl_params, particle)
|
||||||
particle.c - env.center_of_mass, rl_params.half_box_len
|
|
||||||
)
|
|
||||||
|
|
||||||
env.reward =
|
|
||||||
-(vec_to_center_of_mass[1]^2 + vec_to_center_of_mass[2]^2) /
|
|
||||||
rl_params.max_elliptic_distance / rl_params.n_particles
|
|
||||||
|
|
||||||
# Post act
|
# Post act
|
||||||
agent(POST_ACT_STAGE, env)
|
agent(POST_ACT_STAGE, env)
|
||||||
|
@ -343,47 +359,57 @@ function act_hook(
|
||||||
return nothing
|
return nothing
|
||||||
end
|
end
|
||||||
|
|
||||||
function gen_agent(n_states::Int64, n_actions::Int64, ϵ::Float64)
|
function gen_agent(n_states::Int64, n_actions::Int64, ϵ_stable::Float64)
|
||||||
|
# TODO: Optimize warmup and decay
|
||||||
|
warmup_steps = 200_000
|
||||||
|
decay_steps = 1_000_000
|
||||||
|
|
||||||
policy = QBasedPolicy(;
|
policy = QBasedPolicy(;
|
||||||
learner=MonteCarloLearner(;
|
learner=MonteCarloLearner(;
|
||||||
approximator=TabularQApproximator(;
|
approximator=TabularQApproximator(;
|
||||||
n_state=n_states, n_action=n_actions, opt=InvDecay(1.0)
|
n_state=n_states, n_action=n_actions, opt=InvDecay(1.0)
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
explorer=EpsilonGreedyExplorer(ϵ),
|
explorer=EpsilonGreedyExplorer(;
|
||||||
|
kind=:linear,
|
||||||
|
ϵ_init=1.0,
|
||||||
|
ϵ_stable=ϵ_stable,
|
||||||
|
warmup_steps=warmup_steps,
|
||||||
|
decay_steps=decay_steps,
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
return Agent(; policy=policy, trajectory=VectorSARTTrajectory())
|
return Agent(; policy=policy, trajectory=VectorSARTTrajectory())
|
||||||
end
|
end
|
||||||
|
|
||||||
function run_rl(;
|
function run_rl(;
|
||||||
goal_shape_ratio::Float64,
|
goal_gyration_tensor_eigvals_ratio::Float64,
|
||||||
n_episodes::Int64=200,
|
n_episodes::Int64=200,
|
||||||
episode_duration::Float64=50.0,
|
episode_duration::Float64=50.0,
|
||||||
update_actions_at::Float64=0.1,
|
update_actions_at::Float64=0.1,
|
||||||
n_particles::Int64=100,
|
n_particles::Int64=100,
|
||||||
seed::Int64=42,
|
seed::Int64=42,
|
||||||
ϵ::Float64=0.01,
|
ϵ_stable::Float64=0.0001,
|
||||||
parent_dir::String="",
|
parent_dir::String="",
|
||||||
)
|
)
|
||||||
@assert 0.0 <= goal_shape_ratio <= 1.0
|
@assert 0.0 <= goal_gyration_tensor_eigvals_ratio <= 1.0
|
||||||
@assert n_episodes > 0
|
@assert n_episodes > 0
|
||||||
@assert episode_duration > 0
|
@assert episode_duration > 0
|
||||||
@assert update_actions_at in 0.001:0.001:episode_duration
|
@assert update_actions_at in 0.001:0.001:episode_duration
|
||||||
@assert n_particles > 0
|
@assert n_particles > 0
|
||||||
@assert 0.0 < ϵ < 1.0
|
@assert 0.0 < ϵ_stable < 1.0
|
||||||
|
|
||||||
# Setup
|
# Setup
|
||||||
Random.seed!(seed)
|
Random.seed!(seed)
|
||||||
|
|
||||||
sim_consts = ReCo.gen_sim_consts(
|
sim_consts = ReCo.gen_sim_consts(
|
||||||
n_particles, 0.0; skin_to_interaction_r_ratio=1.8, packing_ratio=0.15
|
n_particles, 0.0; skin_to_interaction_r_ratio=1.5, packing_ratio=0.22
|
||||||
)
|
)
|
||||||
n_particles = sim_consts.n_particles
|
n_particles = sim_consts.n_particles
|
||||||
|
|
||||||
env = Env(sim_consts.skin_r)
|
env = Env(; max_distance=sqrt(2) * sim_consts.half_box_len)
|
||||||
|
|
||||||
agent = gen_agent(env.n_states, env.n_actions, ϵ)
|
agent = gen_agent(env.n_states, env.n_actions, ϵ_stable)
|
||||||
|
|
||||||
n_steps_before_actions_update = round(Int64, update_actions_at / sim_consts.δt)
|
n_steps_before_actions_update = round(Int64, update_actions_at / sim_consts.δt)
|
||||||
|
|
||||||
|
@ -394,7 +420,7 @@ function run_rl(;
|
||||||
agent,
|
agent,
|
||||||
hook,
|
hook,
|
||||||
n_steps_before_actions_update,
|
n_steps_before_actions_update,
|
||||||
goal_shape_ratio,
|
goal_gyration_tensor_eigvals_ratio,
|
||||||
n_particles,
|
n_particles,
|
||||||
sim_consts.half_box_len,
|
sim_consts.half_box_len,
|
||||||
)
|
)
|
||||||
|
@ -426,7 +452,9 @@ function run_rl(;
|
||||||
hook(POST_EPISODE_STAGE, agent, env)
|
hook(POST_EPISODE_STAGE, agent, env)
|
||||||
agent(POST_EPISODE_STAGE, env)
|
agent(POST_EPISODE_STAGE, env)
|
||||||
|
|
||||||
|
# TODO: Replace with live plot
|
||||||
display(hook.rewards)
|
display(hook.rewards)
|
||||||
|
display(agent.policy.explorer.step)
|
||||||
end
|
end
|
||||||
|
|
||||||
# Post experiment
|
# Post experiment
|
||||||
|
|
32
src/Shape.jl
32
src/Shape.jl
|
@ -1,20 +1,23 @@
|
||||||
module Shape
|
module Shape
|
||||||
|
|
||||||
export center_of_mass, gyration_tensor_eigvals_ratio
|
export center_of_mass,
|
||||||
|
gyration_tensor_eigvals_ratio, gyration_tensor_eigvecs, elliptical_distance
|
||||||
|
|
||||||
using StaticArrays: SVector, SMatrix
|
using StaticArrays: SVector, SMatrix
|
||||||
using LinearAlgebra: eigvals, Hermitian
|
using LinearAlgebra: eigvals, eigvecs, Hermitian, dot
|
||||||
|
|
||||||
using ..ReCo: Particle, restrict_coordinate, restrict_coordinates
|
using ..ReCo: Particle, restrict_coordinate, restrict_coordinates
|
||||||
|
|
||||||
function project_to_unit_circle(x::Float64, half_box_len::Float64)
|
function project_to_unit_circle(x::Float64, half_box_len::Float64)
|
||||||
φ = (x + half_box_len) * π / half_box_len
|
φ = (x + half_box_len) * π / half_box_len
|
||||||
si, co = sincos(φ)
|
si, co = sincos(φ)
|
||||||
|
|
||||||
return SVector(co, si)
|
return SVector(co, si)
|
||||||
end
|
end
|
||||||
|
|
||||||
function project_back_from_unit_circle(θ::T, half_box_len::Float64) where {T<:Real}
|
function project_back_from_unit_circle(θ::T, half_box_len::Float64) where {T<:Real}
|
||||||
x = θ * half_box_len / π - half_box_len
|
x = θ * half_box_len / π - half_box_len
|
||||||
|
|
||||||
return restrict_coordinate(x, half_box_len)
|
return restrict_coordinate(x, half_box_len)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -87,8 +90,31 @@ function gyration_tensor(particles::Vector{Particle}, half_box_len::Float64)
|
||||||
end
|
end
|
||||||
|
|
||||||
function gyration_tensor_eigvals_ratio(particles::Vector{Particle}, half_box_len::Float64)
|
function gyration_tensor_eigvals_ratio(particles::Vector{Particle}, half_box_len::Float64)
|
||||||
ev = eigvals(gyration_tensor(particles, half_box_len)) # Eigenvalues are sorted
|
g_tensor = gyration_tensor(particles, half_box_len)
|
||||||
|
ev = eigvals(g_tensor) # Eigenvalues are sorted
|
||||||
return ev[1] / ev[2]
|
return ev[1] / ev[2]
|
||||||
end
|
end
|
||||||
|
|
||||||
|
function gyration_tensor_eigvecs(particles::Vector{Particle}, half_box_len::Float64)
|
||||||
|
g_tensor = gyration_tensor(particles, half_box_len)
|
||||||
|
eig_vecs = eigvecs(g_tensor)
|
||||||
|
|
||||||
|
v1 = eig_vecs[:, 1]
|
||||||
|
v2 = eig_vecs[:, 2]
|
||||||
|
|
||||||
|
return (v1, v2)
|
||||||
|
end
|
||||||
|
|
||||||
|
function elliptical_distance(
|
||||||
|
particle::Particle,
|
||||||
|
gyration_tensor_eigvec_to_smaller_eigval::SVector{2,Float64},
|
||||||
|
gyration_tensor_eigvec_to_bigger_eigval::SVector{2,Float64},
|
||||||
|
goal_gyration_tensor_eigvals_ratio::Float64,
|
||||||
|
)
|
||||||
|
cx′ = dot(particle.c, gyration_tensor_eigvec_to_bigger_eigval)
|
||||||
|
cy′ = dot(particle.c, gyration_tensor_eigvec_to_smaller_eigval)
|
||||||
|
|
||||||
|
return cx′^2 + (cy′ / goal_gyration_tensor_eigvals_ratio)^2
|
||||||
|
end
|
||||||
|
|
||||||
end # module
|
end # module
|
Loading…
Reference in a new issue