mirror of
https://gitlab.rlp.net/mobitar/ReCo.jl.git
synced 2024-12-21 00:51:21 +00:00
Added tests
This commit is contained in:
parent
8ad67229a8
commit
7cb7f2d619
6 changed files with 79 additions and 9 deletions
|
@ -15,7 +15,6 @@ GLMakie = "e9467ef8-e4e7-5192-8a1a-b1aee30e663a"
|
|||
Intervals = "d8418881-c3e1-53bb-8760-2df7ec849ed5"
|
||||
JLD2 = "033835bb-8acc-5ee8-8aae-3f567f8a3819"
|
||||
JSON3 = "0f8b85d8-7281-11e9-16c2-39a750bddbf1"
|
||||
JuliaFormatter = "98e50ef6-434e-11e9-1051-2b60c6c9e899"
|
||||
LaTeXStrings = "b964fa9f-0449-5b57-a5c2-d3ea65f4040f"
|
||||
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
|
||||
LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"
|
||||
|
@ -29,4 +28,4 @@ StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
|
|||
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
|
||||
|
||||
[compat]
|
||||
julia = "1.6"
|
||||
julia = ">=1.7"
|
||||
|
|
|
@ -27,14 +27,14 @@ function restrict_coordinate(value::Float64, half_box_len::Float64)
|
|||
return value
|
||||
end
|
||||
|
||||
function restrict_coordinates!(v::SVector{2,Float64}, half_box_len::Float64)
|
||||
function restrict_coordinates(v::SVector{2,Float64}, half_box_len::Float64)
|
||||
return SVector(
|
||||
restrict_coordinate(v[1], half_box_len), restrict_coordinate(v[2], half_box_len)
|
||||
)
|
||||
end
|
||||
|
||||
function restrict_coordinates!(p::Particle, half_box_len::Float64)
|
||||
p.tmp_c = restrict_coordinates!(p.tmp_c, half_box_len)
|
||||
p.tmp_c = restrict_coordinates(p.tmp_c, half_box_len)
|
||||
|
||||
return nothing
|
||||
end
|
||||
|
|
|
@ -1,6 +1,6 @@
|
|||
module RL
|
||||
|
||||
export run
|
||||
export run_rl
|
||||
|
||||
using ReinforcementLearning
|
||||
using Flux: InvDecay
|
||||
|
@ -12,8 +12,6 @@ using ProgressMeter: @showprogress
|
|||
|
||||
using ..ReCo
|
||||
|
||||
import Base: run
|
||||
|
||||
const INITIAL_REWARD = 0.0
|
||||
|
||||
struct DistanceState{L<:Bound}
|
||||
|
@ -369,7 +367,7 @@ function post_integration_hook(
|
|||
return nothing
|
||||
end
|
||||
|
||||
function run(;
|
||||
function run_rl(;
|
||||
goal_shape_ratio::Float64,
|
||||
n_episodes::Int64=100,
|
||||
episode_duration::Float64=50.0,
|
||||
|
|
|
@ -52,7 +52,7 @@ function gyration_tensor(particles::Vector{Particle}, half_box_len::Float64)
|
|||
S22 = 0.0
|
||||
|
||||
for p in particles
|
||||
shifted_c = restrict_coordinates!(p.c - COM, half_box_len)
|
||||
shifted_c = restrict_coordinates(p.c - COM, half_box_len)
|
||||
|
||||
S11 += shifted_c[1]^2
|
||||
S12 += shifted_c[1] * shifted_c[2]
|
||||
|
|
3
test/Project.toml
Normal file
3
test/Project.toml
Normal file
|
@ -0,0 +1,3 @@
|
|||
[deps]
|
||||
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
|
||||
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
|
70
test/runtests.jl
Normal file
70
test/runtests.jl
Normal file
|
@ -0,0 +1,70 @@
|
|||
using Test
|
||||
using ReCo
|
||||
|
||||
using StaticArrays: SVector
|
||||
|
||||
@testset "Particle.jl" begin
|
||||
half_box_len = 1.0
|
||||
|
||||
@testset "restrict_coordinates" begin
|
||||
@test ReCo.restrict_coordinates(SVector(1.5, -0.8), half_box_len) ≈
|
||||
SVector(-0.5, -0.8)
|
||||
end
|
||||
|
||||
@testset "are_overlapping" begin
|
||||
overlapping_r2 = 1.0
|
||||
|
||||
overlapping, r⃗₁₂, distance2 = ReCo.are_overlapping(
|
||||
SVector(-0.4, 0.0), SVector(0.4, 0.0), overlapping_r2, half_box_len
|
||||
)
|
||||
@test overlapping == true
|
||||
@test r⃗₁₂ ≈ SVector(0.8, 0.0)
|
||||
@test distance2 ≈ 0.8^2
|
||||
|
||||
overlapping, r⃗₁₂, distance2 = ReCo.are_overlapping(
|
||||
SVector(-0.6, 0.0), SVector(0.6, 0.0), overlapping_r2, half_box_len
|
||||
)
|
||||
@test overlapping == true
|
||||
@test r⃗₁₂ ≈ SVector(-0.8, 0.0)
|
||||
@test distance2 ≈ 0.8^2
|
||||
|
||||
overlapping_r2 = 0.5^2
|
||||
overlapping, r⃗₁₂, distance2 = ReCo.are_overlapping(
|
||||
SVector(-0.3, 0.0), SVector(0.3, 0.0), overlapping_r2, half_box_len
|
||||
)
|
||||
@test overlapping == false
|
||||
@test r⃗₁₂ ≈ SVector(0.6, 0.0)
|
||||
@test distance2 ≈ 0.6^2
|
||||
end
|
||||
end
|
||||
|
||||
@testset "shape.jl" begin
|
||||
n_particles = 10
|
||||
v₀ = 0.0
|
||||
sim_consts = ReCo.gen_sim_consts(n_particles, v₀)
|
||||
|
||||
@testset "gen_sim_consts" begin
|
||||
@test sim_consts.n_particles == 16
|
||||
end
|
||||
|
||||
half_box_len = sim_consts.half_box_len
|
||||
|
||||
@testset "project_to_unit_circle" begin
|
||||
@test ReCo.project_to_unit_circle(0.0, half_box_len) ≈ SVector(-1.0, 0.0)
|
||||
@test ReCo.project_to_unit_circle(half_box_len, half_box_len) ≈
|
||||
ReCo.project_to_unit_circle(-half_box_len, half_box_len) ≈
|
||||
SVector(1.0, 0.0)
|
||||
end
|
||||
|
||||
particles = ReCo.gen_particles(
|
||||
sim_consts.grid_n, sim_consts.grid_box_width, half_box_len
|
||||
)
|
||||
|
||||
@testset "center_of_mass" begin
|
||||
@test ReCo.center_of_mass(particles, half_box_len) ≈ SVector(0.0, 0.0)
|
||||
end
|
||||
|
||||
@testset "gyration_tensor" begin
|
||||
@test ReCo.gyration_tensor_eigvals_ratio(particles, half_box_len) == 1.0
|
||||
end
|
||||
end
|
Loading…
Reference in a new issue