diff --git a/src/RL/Envs/LocalCOMWithAdditionalShapeRewardEnv.jl b/src/RL/Envs/LocalCOMWithAdditionalShapeRewardEnv.jl
index 99ac5b8..25a5c8c 100644
--- a/src/RL/Envs/LocalCOMWithAdditionalShapeRewardEnv.jl
+++ b/src/RL/Envs/LocalCOMWithAdditionalShapeRewardEnv.jl
@@ -114,8 +114,8 @@ function state_update_hook!(
 
     distance_to_local_center_of_mass_sum = 0.0
 
-    for id in 1:n_particles
-        n_neighbours = env_helper.n_neighbours[id]
+    for particle_id in 1:n_particles
+        n_neighbours = env_helper.n_neighbours[particle_id]
 
         if n_neighbours == 0
             state_id = env.shared.n_states
@@ -124,20 +124,14 @@ function state_update_hook!(
                 env_helper.max_distance_to_local_center_of_mass
         else
             vec_to_local_center_of_mass =
-                env_helper.vec_to_neighbour_sums[id] / n_neighbours
-
+                env_helper.vec_to_neighbour_sums[particle_id] / n_neighbours
             distance = ReCo.norm2d(vec_to_local_center_of_mass)
-
-            env_helper.distances_to_local_center_of_mass[id] = distance
-
+            env_helper.distances_to_local_center_of_mass[particle_id] = distance
             distance_to_local_center_of_mass_sum += distance
-
             distance_state_ind = find_state_ind(distance, env.distance_state_space)
 
-            si, co = sincos(particles[id].φ)
-
+            si, co = sincos(particles[particle_id].φ)
             direction_angle = ReCo.angle2(SVector(co, si), vec_to_local_center_of_mass)
-
             direction_state_ind = find_state_ind(
                 direction_angle, env.direction_angle_state_space
             )
@@ -145,7 +139,7 @@ function state_update_hook!(
             state_id = env.shared.state_id_tensor[distance_state_ind, direction_state_ind]
         end
 
-        env_helper.shared.states_id[id] = state_id
+        env_helper.shared.states_id[particle_id] = state_id
     end
 
     mean_distance_to_local_center_of_mass =
@@ -171,16 +165,14 @@ function update_reward!(
     env_helper::LocalCOMWithAdditionalShapeRewardEnvHelper,
     particle::ReCo.Particle,
 )
-    id = particle.id
-
     normalization = env_helper.shared.n_particles # TODO: Add factor from steps
-    n_neighbours = env_helper.n_neighbours[id]
+    n_neighbours = env_helper.n_neighbours[particle.id]
 
     if n_neighbours == 0
         env.shared.reward = 0.0
     else
         reward = minimizing_reward(
-            env_helper.distances_to_local_center_of_mass[id],
+            env_helper.distances_to_local_center_of_mass[particle.id],
             env_helper.max_distance_to_local_center_of_mass,
         )
 
diff --git a/src/RL/Envs/OriginEnv.jl b/src/RL/Envs/OriginEnv.jl
new file mode 100644
index 0000000..1b92e5c
--- /dev/null
+++ b/src/RL/Envs/OriginEnv.jl
@@ -0,0 +1,146 @@
+export OriginEnv
+
+using ..ReCo: ReCo
+
+"""
+    OriginEnv(; n_distance_states=4, n_direction_angle_states=3, args)
+
+Environment with a distance state space and a direction angle state space.
+
+The distance state space is generated by `gen_distance_state_space` from `0.0` to
+`sqrt(2) * args.half_box_len`, the direction angle state space by
+`gen_angle_state_space`. Both state counts must be greater than 1, and `args` must
+provide `half_box_len`.
+"""
+struct OriginEnv <: Env
+    shared::EnvSharedProps
+
+    distance_state_space::Vector{Interval}
+    direction_angle_state_space::Vector{Interval}
+
+    function OriginEnv(;
+        n_distance_states::Int64=4, n_direction_angle_states::Int64=3, args
+    )
+        @assert n_distance_states > 1
+        @assert n_direction_angle_states > 1
+
+        direction_angle_state_space = gen_angle_state_space(n_direction_angle_states)
+
+        min_distance = 0.0
+        max_distance = sqrt(2) * args.half_box_len
+
+        distance_state_space = gen_distance_state_space(
+            min_distance, max_distance, n_distance_states
+        )
+
+        n_states = n_distance_states * n_direction_angle_states
+
+        shared = EnvSharedProps(n_states, (n_distance_states, n_direction_angle_states))
+
+        return new(shared, distance_state_space, direction_angle_state_space)
+    end
+end
+
+"""
+    OriginEnvHelper(shared::EnvHelperSharedProps, half_box_len::Float64)
+
+Mutable helper state for `OriginEnv`.
+
+`distances_to_origin` holds one scalar distance per particle, initialised to zero;
+`max_distance_to_origin` is `sqrt(2) * half_box_len`.
+"""
+mutable struct OriginEnvHelper <: EnvHelper
+    shared::EnvHelperSharedProps
+
+    distances_to_origin::Vector{Float64}
+    max_distance_to_origin::Float64
+
+    half_box_len::Float64
+
+    function OriginEnvHelper(shared::EnvHelperSharedProps, half_box_len::Float64)
+        max_distance_to_origin = sqrt(2) * half_box_len
+
+        return new(
+            shared,
+            # Scalar zeros: the field is a Vector{Float64}, not a vector of SVectors.
+            zeros(shared.n_particles),
+            max_distance_to_origin,
+            half_box_len,
+        )
+    end
+end
+
+"""
+    gen_env_helper(::OriginEnv, env_helper_shared::EnvHelperSharedProps; args)
+
+Create the `OriginEnvHelper` for an `OriginEnv`, using `args.half_box_len`.
+"""
+function gen_env_helper(::OriginEnv, env_helper_shared::EnvHelperSharedProps; args)
+    return OriginEnvHelper(env_helper_shared, args.half_box_len)
+end
+
+# No-op for this environment.
+function pre_integration_hook!(::OriginEnvHelper)
+    return nothing
+end
+
+# No-op for this environment; the arguments are unused.
+function state_update_helper_hook!(
+    ::OriginEnvHelper, id1::Int64, id2::Int64, r⃗₁₂::SVector{2,Float64}
+)
+    return nothing
+end
+
+"""
+    state_update_hook!(env_helper::OriginEnvHelper, particles::Vector{ReCo.Particle})
+
+Store each particle's distance to the origin and its state id in `env_helper`.
+
+The state id is read from `env.shared.state_id_tensor` using the distance state index
+and the state index of the angle between the orientation `(cos φ, sin φ)` and
+`vec_to_origin = -particle.c`.
+"""
+function state_update_hook!(env_helper::OriginEnvHelper, particles::Vector{ReCo.Particle})
+    n_particles = env_helper.shared.n_particles
+
+    env = env_helper.shared.env
+
+    for particle_id in 1:n_particles
+        vec_to_origin = -particles[particle_id].c
+        distance_to_origin = ReCo.norm2d(vec_to_origin)
+        env_helper.distances_to_origin[particle_id] = distance_to_origin
+        distance_state_ind = find_state_ind(distance_to_origin, env.distance_state_space)
+
+        si, co = sincos(particles[particle_id].φ)
+        direction_angle = ReCo.angle2(SVector(co, si), vec_to_origin)
+        direction_state_ind = find_state_ind(
+            direction_angle, env.direction_angle_state_space
+        )
+
+        state_id = env.shared.state_id_tensor[distance_state_ind, direction_state_ind]
+
+        env_helper.shared.states_id[particle_id] = state_id
+    end
+
+    return nothing
+end
+
+"""
+    update_reward!(env::OriginEnv, env_helper::OriginEnvHelper, particle::ReCo.Particle)
+
+Set `env.shared.reward` to `minimizing_reward` of the particle's stored distance to the
+origin and the maximum distance, divided by the number of particles.
+"""
+function update_reward!(
+    env::OriginEnv, env_helper::OriginEnvHelper, particle::ReCo.Particle
+)
+    normalization = env_helper.shared.n_particles # TODO: Add factor from steps
+
+    reward = minimizing_reward(
+        env_helper.distances_to_origin[particle.id], env_helper.max_distance_to_origin
+    )
+
+    env.shared.reward = reward / normalization
+
+    return nothing
+end
\ No newline at end of file
diff --git a/src/RL/RL.jl b/src/RL/RL.jl
index 5570df1..25ba548 100644
--- a/src/RL/RL.jl
+++ b/src/RL/RL.jl
@@ -1,6 +1,6 @@
 module RL
 
-export run_rl, LocalCOMWithAdditionalShapeRewardEnv
+export run_rl, LocalCOMWithAdditionalShapeRewardEnv, OriginEnv
 
 using Base: OneTo
 
@@ -141,5 +141,6 @@ function run_rl(;
 end
 
 include("Envs/LocalCOMWithAdditionalShapeRewardEnv.jl")
+include("Envs/OriginEnv.jl")
 
 end # module
\ No newline at end of file