diff --git a/src/pybdsim/Beam.py b/src/pybdsim/Beam.py index 7010bdd..538de2f 100644 --- a/src/pybdsim/Beam.py +++ b/src/pybdsim/Beam.py @@ -111,7 +111,7 @@ def _MakeGaussTwiss(self): setattr(self, 'SetDispYP', self._SetDispYP) def SetDistributionType(self,distrType='reference'): - if distrType not in BDSIMDistributionTypes: + if distrType not in BDSIMDistributionTypes and not 'bdsimsampler:' in distrType and not 'eventgeneratorfile:' in distrType: raise ValueError("Unknown distribution type: '"+str(distrType)+"'") self['distrType'] = '"' + distrType + '"' self._UpdateMemberFunctions(distrType) diff --git a/src/pybdsim/Builder.py b/src/pybdsim/Builder.py index fb61e3a..93d786c 100644 --- a/src/pybdsim/Builder.py +++ b/src/pybdsim/Builder.py @@ -1325,14 +1325,15 @@ def Insert(self, newElement, index = 0, after = False, substitute = False): if newElement.name in self.elements.keys(): if not substitute: raise ValueError(f"New element {newElement.name} already exists in elements. If you want to overwrite it, set substitute=True") - else: - self.elements[newElement.name] = newElement + self.elements[newElement.name] = newElement + + # Modify sequence + self.sequence.insert(index, newElement.name) elif isinstance(newElement, str): if newElement not in self.elements.keys(): raise ValueError(f"New element {newElement} not found in elements") - - # Modify sequence - self.sequence.insert(index, newElement.name) + # Modify sequence + self.sequence.insert(index, newElement) def InsertAndReplace(self, newElement, sLocation = 0, element_name = None): """ diff --git a/tests/test_builder.py b/tests/test_builder.py index 9d8a2e9..1104012 100644 --- a/tests/test_builder.py +++ b/tests/test_builder.py @@ -257,6 +257,31 @@ def test_Laser_repr(): expected = 'myl: laser, l=0.1, waveLength=5370, x=0.2, y=0.3, z=0.4;\n' assert repr(laser) == expected +def test_Insert(): + machine = pybdsim.Builder.Machine() + machine.AddDrift(name="dr", length=0.1) + machine.Insert(pybdsim.Builder.RBend("myr", 0.1, angle=0), index = "dr") + machine.Insert(pybdsim.Builder.RBend("myr1", 0.1, angle=0), index = "dr", after=True) + machine.Insert("dr", index = "myr1", after=True, substitute=True) + expected = ['myr', 'dr', 'myr1', 'dr'] + assert machine.sequence == expected + +def test_bdsimsampler_repr(): + beam = pybdsim.Beam.Beam() + beam.SetDistributionType("bdsimsampler:SAMPLER") + beam.SetEnergy(100) + beam.SetParticleType("proton") + expected = 'beam,\tdistrType="bdsimsampler:SAMPLER",\n\tenergy=100*GeV, \n\tparticle="proton";' + assert repr(beam) == expected + +def test_eventgeneratorfile_repr(): + beam = pybdsim.Beam.Beam() + beam.SetDistributionType("eventgeneratorfile:FORMAT") + beam.SetEnergy(100) + beam.SetParticleType("proton") + expected = 'beam,\tdistrType="eventgeneratorfile:FORMAT",\n\tenergy=100*GeV, \n\tparticle="proton";' + assert repr(beam) == expected + #def test_element_split_drift(): # c = pybdsim.Builder.Element('d1', 'drift', l=(0.4, 'm'), aper1=(2, 'cm')) # b = c/2