mirror of
https://gitlab.rlp.net/mobitar/ReCo.jl.git
synced 2024-12-21 00:51:21 +00:00
Added tests
This commit is contained in:
parent
8ad67229a8
commit
7cb7f2d619
6 changed files with 79 additions and 9 deletions
|
@ -15,7 +15,6 @@ GLMakie = "e9467ef8-e4e7-5192-8a1a-b1aee30e663a"
|
||||||
Intervals = "d8418881-c3e1-53bb-8760-2df7ec849ed5"
|
Intervals = "d8418881-c3e1-53bb-8760-2df7ec849ed5"
|
||||||
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
|
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
|
||||||
JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
|
JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
|
||||||
JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
|
|
||||||
LaTeXStrings = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f"
|
LaTeXStrings = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f"
|
||||||
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
|
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
|
||||||
LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"
|
LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"
|
||||||
|
@ -29,4 +28,4 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
|
||||||
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
|
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
|
||||||
|
|
||||||
[compat]
|
[compat]
|
||||||
julia = "1.6"
|
julia = ">=1.7"
|
||||||
|
|
|
@ -27,14 +27,14 @@ function restrict_coordinate(value::Float64, half_box_len::Float64)
|
||||||
return value
|
return value
|
||||||
end
|
end
|
||||||
|
|
||||||
function restrict_coordinates!(v::SVector{2,Float64}, half_box_len::Float64)
|
function restrict_coordinates(v::SVector{2,Float64}, half_box_len::Float64)
|
||||||
return SVector(
|
return SVector(
|
||||||
restrict_coordinate(v[1], half_box_len), restrict_coordinate(v[2], half_box_len)
|
restrict_coordinate(v[1], half_box_len), restrict_coordinate(v[2], half_box_len)
|
||||||
)
|
)
|
||||||
end
|
end
|
||||||
|
|
||||||
function restrict_coordinates!(p::Particle, half_box_len::Float64)
|
function restrict_coordinates!(p::Particle, half_box_len::Float64)
|
||||||
p.tmp_c = restrict_coordinates!(p.tmp_c, half_box_len)
|
p.tmp_c = restrict_coordinates(p.tmp_c, half_box_len)
|
||||||
|
|
||||||
return nothing
|
return nothing
|
||||||
end
|
end
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
module RL
|
module RL
|
||||||
|
|
||||||
export run
|
export run_rl
|
||||||
|
|
||||||
using ReinforcementLearning
|
using ReinforcementLearning
|
||||||
using Flux: InvDecay
|
using Flux: InvDecay
|
||||||
|
@ -12,8 +12,6 @@ using ProgressMeter: @showprogress
|
||||||
|
|
||||||
using ..ReCo
|
using ..ReCo
|
||||||
|
|
||||||
import Base: run
|
|
||||||
|
|
||||||
const INITIAL_REWARD = 0.0
|
const INITIAL_REWARD = 0.0
|
||||||
|
|
||||||
struct DistanceState{L<:Bound}
|
struct DistanceState{L<:Bound}
|
||||||
|
@ -369,7 +367,7 @@ function post_integration_hook(
|
||||||
return nothing
|
return nothing
|
||||||
end
|
end
|
||||||
|
|
||||||
function run(;
|
function run_rl(;
|
||||||
goal_shape_ratio::Float64,
|
goal_shape_ratio::Float64,
|
||||||
n_episodes::Int64=100,
|
n_episodes::Int64=100,
|
||||||
episode_duration::Float64=50.0,
|
episode_duration::Float64=50.0,
|
||||||
|
|
|
@ -52,7 +52,7 @@ function gyration_tensor(particles::Vector{Particle}, half_box_len::Float64)
|
||||||
S22 = 0.0
|
S22 = 0.0
|
||||||
|
|
||||||
for p in particles
|
for p in particles
|
||||||
shifted_c = restrict_coordinates!(p.c - COM, half_box_len)
|
shifted_c = restrict_coordinates(p.c - COM, half_box_len)
|
||||||
|
|
||||||
S11 += shifted_c[1]^2
|
S11 += shifted_c[1]^2
|
||||||
S12 += shifted_c[1] * shifted_c[2]
|
S12 += shifted_c[1] * shifted_c[2]
|
||||||
|
|
3
test/Project.toml
Normal file
3
test/Project.toml
Normal file
|
@ -0,0 +1,3 @@
|
||||||
|
[deps]
|
||||||
|
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
|
||||||
|
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
|
70
test/runtests.jl
Normal file
70
test/runtests.jl
Normal file
|
@ -0,0 +1,70 @@
|
||||||
|
using Test
|
||||||
|
using ReCo
|
||||||
|
|
||||||
|
using StaticArrays: SVector
|
||||||
|
|
||||||
|
@testset "Particle.jl" begin
|
||||||
|
half_box_len = 1.0
|
||||||
|
|
||||||
|
@testset "restrict_coordinates" begin
|
||||||
|
@test ReCo.restrict_coordinates(SVector(1.5, -0.8), half_box_len) ≈
|
||||||
|
SVector(-0.5, -0.8)
|
||||||
|
end
|
||||||
|
|
||||||
|
@testset "are_overlapping" begin
|
||||||
|
overlapping_r2 = 1.0
|
||||||
|
|
||||||
|
overlapping, r⃗₁₂, distance2 = ReCo.are_overlapping(
|
||||||
|
SVector(-0.4, 0.0), SVector(0.4, 0.0), overlapping_r2, half_box_len
|
||||||
|
)
|
||||||
|
@test overlapping == true
|
||||||
|
@test r⃗₁₂ ≈ SVector(0.8, 0.0)
|
||||||
|
@test distance2 ≈ 0.8^2
|
||||||
|
|
||||||
|
overlapping, r⃗₁₂, distance2 = ReCo.are_overlapping(
|
||||||
|
SVector(-0.6, 0.0), SVector(0.6, 0.0), overlapping_r2, half_box_len
|
||||||
|
)
|
||||||
|
@test overlapping == true
|
||||||
|
@test r⃗₁₂ ≈ SVector(-0.8, 0.0)
|
||||||
|
@test distance2 ≈ 0.8^2
|
||||||
|
|
||||||
|
overlapping_r2 = 0.5^2
|
||||||
|
overlapping, r⃗₁₂, distance2 = ReCo.are_overlapping(
|
||||||
|
SVector(-0.3, 0.0), SVector(0.3, 0.0), overlapping_r2, half_box_len
|
||||||
|
)
|
||||||
|
@test overlapping == false
|
||||||
|
@test r⃗₁₂ ≈ SVector(0.6, 0.0)
|
||||||
|
@test distance2 ≈ 0.6^2
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
@testset "shape.jl" begin
|
||||||
|
n_particles = 10
|
||||||
|
v₀ = 0.0
|
||||||
|
sim_consts = ReCo.gen_sim_consts(n_particles, v₀)
|
||||||
|
|
||||||
|
@testset "gen_sim_consts" begin
|
||||||
|
@test sim_consts.n_particles == 16
|
||||||
|
end
|
||||||
|
|
||||||
|
half_box_len = sim_consts.half_box_len
|
||||||
|
|
||||||
|
@testset "project_to_unit_circle" begin
|
||||||
|
@test ReCo.project_to_unit_circle(0.0, half_box_len) ≈ SVector(-1.0, 0.0)
|
||||||
|
@test ReCo.project_to_unit_circle(half_box_len, half_box_len) ≈
|
||||||
|
ReCo.project_to_unit_circle(-half_box_len, half_box_len) ≈
|
||||||
|
SVector(1.0, 0.0)
|
||||||
|
end
|
||||||
|
|
||||||
|
particles = ReCo.gen_particles(
|
||||||
|
sim_consts.grid_n, sim_consts.grid_box_width, half_box_len
|
||||||
|
)
|
||||||
|
|
||||||
|
@testset "center_of_mass" begin
|
||||||
|
@test ReCo.center_of_mass(particles, half_box_len) ≈ SVector(0.0, 0.0)
|
||||||
|
end
|
||||||
|
|
||||||
|
@testset "gyration_tensor" begin
|
||||||
|
@test ReCo.gyration_tensor_eigvals_ratio(particles, half_box_len) == 1.0
|
||||||
|
end
|
||||||
|
end
|
Loading…
Reference in a new issue