Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/pybdsim/Beam.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,7 +111,7 @@ def _MakeGaussTwiss(self):
setattr(self, 'SetDispYP', self._SetDispYP)

def SetDistributionType(self,distrType='reference'):
if distrType not in BDSIMDistributionTypes:
if distrType not in BDSIMDistributionTypes and not 'bdsimsampler:' in distrType:
raise ValueError("Unknown distribution type: '"+str(distrType)+"'")
self['distrType'] = '"' + distrType + '"'
self._UpdateMemberFunctions(distrType)
Expand Down
11 changes: 6 additions & 5 deletions src/pybdsim/Builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -1325,14 +1325,15 @@ def Insert(self, newElement, index = 0, after = False, substitute = False):
if newElement.name in self.elements.keys():
if not substitute:
raise ValueError(f"New element {newElement.name} already exists in elements. If you want to overwrite it, set substitute=True")
else:
self.elements[newElement.name] = newElement
self.elements[newElement.name] = newElement

# Modify sequence
self.sequence.insert(index, newElement.name)
elif isinstance(newElement, str):
if newElement not in self.elements.keys():
raise ValueError(f"New element {newElement} not found in elements")

# Modify sequence
self.sequence.insert(index, newElement.name)
# Modify sequence
self.sequence.insert(index, newElement)

def InsertAndReplace(self, newElement, sLocation = 0, element_name = None):
"""
Expand Down
17 changes: 17 additions & 0 deletions tests/test_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -257,6 +257,23 @@ def test_Laser_repr():
expected = 'myl: laser, l=0.1, waveLength=5370, x=0.2, y=0.3, z=0.4;\n'
assert repr(laser) == expected

def test_Insert():
machine = pybdsim.Builder.Machine()
machine.AddDrift(name="dr", length=0.1)
machine.Insert(pybdsim.Builder.RBend("myr", 0.1, angle=0), index = "dr")
machine.Insert(pybdsim.Builder.RBend("myr1", 0.1, angle=0), index = "dr", after=True)
machine.Insert("dr", index = "myr1", after=True, substitute=True)
expected = ['myr', 'dr', 'myr1', 'dr']
assert machine.sequence == expected

def test_bdsimsampler_repr():
beam = pybdsim.Beam.Beam()
beam.SetDistributionType("bdsimsampler:SAMPLER")
beam.SetEnergy(100)
beam.SetParticleType("proton")
expected = 'beam,\tdistrType="bdsimsampler:SAMPLER",\n\tenergy=100*GeV, \n\tparticle="proton";'
assert repr(beam) == expected

#def test_element_split_drift():
# c = pybdsim.Builder.Element('d1', 'drift', l=(0.4, 'm'), aper1=(2, 'cm'))
# b = c/2
Expand Down